Commit 36d6f41f authored by saberkun's avatar saberkun
Browse files

Merge pull request #10514 from ZihanWangKi:master

PiperOrigin-RevId: 436548878
parents bc8b6332 259c4347
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Common modules for callbacks.""" """Common modules for callbacks."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Runs an Image Classification model.""" """Runs an Image Classification model."""
import os import os
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models.""" """Unit tests for the classifier trainer models."""
import functools import functools
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Unit tests for the classifier trainer models.""" """Unit tests for the classifier trainer models."""
from __future__ import absolute_import from __future__ import absolute_import
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset.""" """Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Learning rate utilities for vision tasks.""" """Learning rate utilities for vision tasks."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Contains definitions for the AssembleNet [1] models. """Contains definitions for the AssembleNet [1] models.
Requires the AssembleNet architecture to be specified in Requires the AssembleNet architecture to be specified in
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Contains definitions for 'Representation Flow' layer [1]. """Contains definitions for 'Representation Flow' layer [1].
Representation flow layer is a generalization of optical flow extraction; the Representation flow layer is a generalization of optical flow extraction; the
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Detection input and model functions for serving/inference.""" """Detection input and model functions for serving/inference."""
from typing import Dict, Mapping, Text from typing import Dict, Mapping, Text
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Test for image detection export lib.""" """Test for image detection export lib."""
import io import io
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for image classification task.""" """Tests for image classification task."""
# pylint: disable=unused-import # pylint: disable=unused-import
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for semantic segmentation task.""" """Tests for semantic segmentation task."""
# pylint: disable=unused-import # pylint: disable=unused-import
......
# Longformer: The Long-Document Transformer
## Modifications from Huggingface's Implementation
All models require a `global_attention_size` specified in the config, setting a
global attention for all first `global_attention_size` tokens in any sentence.
Individual different global attention sizes for sentences are not supported.
This setting allows running on TPUs where tensor sizes have to be determined.
`_get_global_attn_indices` in `longformer_attention.py` contains how the new
global attention indices are specified. Changed all `tf.cond` to if
confiditions, since global attention is specified in the start now.
To load weights from a pre-trained huggingface longformer, run
`utils/convert_pretrained_pytorch_checkpoint_to_tf.py` to create a checkpoint. \
There is also a `utils/longformer_tokenizer_to_tfrecord.py` that transformers
pytorch longformer tokenized data to tf_records.
## Steps to Fine-tune on MNLI
#### Prepare the pre-trained checkpoint
Option 1. Use our saved checkpoint of `allenai/longformer-base-4096` stored in cloud storage
```bash
gsutil cp -r gs://model-garden-ucsd-zihan/longformer-4096 .
```
Option 2. Create it directly
```bash
python3 utils/convert_pretrained_pytorch_checkpoint_to_tf.py
```
#### [Optional] Prepare the input file
```bash
python3 longformer_tokenizer_to_tfrecord.py
```
#### Training
Here, we use the training data of MNLI that were uploaded to the cloud storage, you can replace it with the input files you generated.
```bash
TRAIN_DATA=task.train_data.input_path=gs://model-garden-ucsd-zihan/longformer_allenai_mnli_train.tf_record,task.validation_data.input_path=gs://model-garden-ucsd-zihan/longformer_allenai_mnli_eval.tf_record
INIT_CHECKPOINT=longformer-4096/longformer
PYTHONPATH=/path/to/model/garden \
python3 train.py \
--experiment=longformer/glue \
--config_file=experiments/glue_mnli_allenai.yaml \
--params_override="${TRAIN_DATA},runtime.distribution_strategy=tpu,task.init_checkpoint=${INIT_CHECKPOINT}" \
--tpu=local \
--model_dir=/path/to/outputdir \
--mode=train_and_eval
```
This should take ~ 3 hours to run, and give a performance of ~86.
task:
hub_module_url: ''
model:
num_classes: 3
encoder:
type: any
any:
max_position_embeddings: 512
attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
global_attention_size: 1
metric_type: 'accuracy'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: true
seq_length: 128
validation_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: false
seq_length: 128
trainer:
checkpoint_interval: 1000
continuous_eval_timeout: 7200
optimizer_config:
learning_rate:
polynomial:
decay_steps: 61359
end_learning_rate: 0.0
initial_learning_rate: 3.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 6136
type: polynomial
steps_per_loop: 100
summary_interval: 100
# Training data size 392,702 examples, 5 epochs.
train_steps: 61359
validation_interval: 2000
validation_steps: 307
task:
hub_module_url: ''
model:
num_classes: 3
encoder:
type: any
any:
max_position_embeddings: 4098
attention_window: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
global_attention_size: 1
vocab_size: 50265
metric_type: 'accuracy'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: true
seq_length: 512
validation_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: false
seq_length: 512
trainer:
checkpoint_interval: 1000
continuous_eval_timeout: 7200
optimizer_config:
learning_rate:
polynomial:
decay_steps: 61359
end_learning_rate: 0.0
initial_learning_rate: 3.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 6136
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
# Training data size 392,702 examples, 5 epochs.
train_steps: 61359
validation_interval: 2000
validation_steps: 307
task:
init_checkpoint: ""
model:
cls_heads:
[
{
activation: tanh,
cls_token_idx: 0,
dropout_rate: 0.1,
inner_dim: 768,
name: next_sentence,
num_classes: 2,
},
]
encoder:
type: any
any:
attention_dropout_rate: 0.1
dropout_rate: 0.1
embedding_size: 768
hidden_activation: gelu
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
max_position_embeddings: 512
num_attention_heads: 12
num_layers: 12
type_vocab_size: 2
vocab_size: 30522
attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
global_attention_size: 1
train_data:
drop_remainder: true
global_batch_size: 256
input_path: gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00000-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00001-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00002-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00003-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00004-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00005-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00006-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00007-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00008-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00009-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00010-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00011-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00012-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00013-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00014-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00015-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00016-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00017-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00018-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00019-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00020-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00021-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00022-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00023-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00024-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00025-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00026-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00027-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00028-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00029-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00030-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00031-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00032-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00033-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00034-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00035-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00036-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00037-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00038-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00039-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00040-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00041-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00042-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00043-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00044-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00045-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00046-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00047-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00048-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00049-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00050-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00051-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00052-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00053-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00054-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00055-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00056-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00057-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00058-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00059-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00060-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00061-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00062-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00063-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00064-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00065-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00066-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00067-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00068-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00069-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00070-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00071-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00072-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00073-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00074-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00075-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00076-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00077-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00078-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00079-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00080-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00081-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00082-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00083-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00084-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00085-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00086-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00087-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00088-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00089-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00090-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00091-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00092-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00093-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00094-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00095-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00096-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00097-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00098-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00099-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00100-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00101-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00102-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00103-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00104-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00105-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00106-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00107-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00108-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00109-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00110-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00111-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00112-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00113-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00114-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00115-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00116-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00117-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00118-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00119-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00120-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00121-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00122-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00123-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00124-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00125-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00126-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00127-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00128-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00129-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00130-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00131-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00132-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00133-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00134-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00135-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00136-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00137-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00138-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00139-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00140-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00141-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00142-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00143-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00144-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00145-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00146-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00147-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00148-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00149-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00150-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00151-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00152-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00153-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00154-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00155-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00156-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00157-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00158-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00159-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00160-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00161-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00162-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00163-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00164-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00165-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00166-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00167-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00168-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00169-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00170-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00171-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00172-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00173-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00174-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00175-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00176-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00177-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00178-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00179-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00180-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00181-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00182-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00183-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00184-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00185-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00186-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00187-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00188-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00189-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00190-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00191-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00192-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00193-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00194-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00195-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00196-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00197-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00198-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00199-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00200-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00201-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00202-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00203-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00204-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00205-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00206-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00207-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00208-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00209-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00210-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00211-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00212-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00213-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00214-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00215-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00216-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00217-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00218-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00219-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00220-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00221-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00222-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00223-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00224-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00225-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00226-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00227-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00228-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00229-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00230-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00231-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00232-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00233-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00234-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00235-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00236-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00237-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00238-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00239-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00240-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00241-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00242-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00243-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00244-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00245-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00246-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00247-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00248-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00249-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00250-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00251-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00252-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00253-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00254-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00255-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00256-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00257-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00258-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00259-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00260-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00261-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00262-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00263-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00264-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00265-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00266-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00267-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00268-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00269-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00270-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00271-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00272-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00273-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00274-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00275-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00276-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00277-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00278-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00279-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00280-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00281-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00282-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00283-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00284-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00285-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00286-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00287-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00288-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00289-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00290-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00291-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00292-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00293-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00294-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00295-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00296-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00297-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00298-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00299-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00300-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00301-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00302-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00303-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00304-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00305-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00306-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00307-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00308-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00309-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00310-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00311-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00312-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00313-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00314-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00315-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00316-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00317-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00318-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00319-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00320-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00321-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00322-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00323-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00324-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00325-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00326-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00327-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00328-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00329-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00330-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00331-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00332-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00333-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00334-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00335-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00336-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00337-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00338-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00339-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00340-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00341-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00342-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00343-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00344-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00345-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00346-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00347-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00348-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00349-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00350-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00351-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00352-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00353-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00354-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00355-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00356-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00357-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00358-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00359-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00360-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00361-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00362-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00363-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00364-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00365-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00366-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00367-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00368-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00369-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00370-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00371-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00372-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00373-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00374-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00375-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00376-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00377-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00378-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00379-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00380-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00381-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00382-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00383-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00384-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00385-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00386-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00387-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00388-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00389-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00390-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00391-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00392-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00393-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00394-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00395-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00396-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00397-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00398-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00399-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00400-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00401-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00402-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00403-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00404-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00405-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00406-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00407-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00408-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00409-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00410-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00411-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00412-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00413-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00414-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00415-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00416-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00417-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00418-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00419-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00420-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00421-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00422-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00423-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00424-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00425-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00426-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00427-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00428-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00429-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00430-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00431-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00432-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00433-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00434-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00435-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00436-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00437-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00438-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00439-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00440-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00441-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00442-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00443-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00444-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00445-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00446-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00447-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00448-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00449-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00450-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00451-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00452-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00453-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00454-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00455-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00456-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00457-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00458-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00459-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00460-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00461-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00462-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00463-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00464-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00465-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00466-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00467-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00468-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00469-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00470-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00471-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00472-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00473-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00474-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00475-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00476-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00477-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00478-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00479-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00480-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00481-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00482-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00483-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00484-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00485-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00486-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00487-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00488-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00489-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00490-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00491-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00492-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00493-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00494-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00495-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00496-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00497-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00498-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00499-of-00500
is_training: true
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
validation_data:
drop_remainder: true
global_batch_size: 256
input_path: TODO
is_training: false
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
trainer:
checkpoint_interval: 20000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
cycle: false
decay_steps: 1000000
end_learning_rate: 0.0
initial_learning_rate: 0.0001
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 10000
type: polynomial
steps_per_loop: 50
summary_interval: 50
train_steps: 1000000
validation_interval: 1000
validation_steps: 64
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer model configurations and instantiation methods."""
import dataclasses
from typing import List
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.projects.longformer.longformer_encoder import LongformerEncoder
@dataclasses.dataclass
class LongformerEncoderConfig(encoders.BertEncoderConfig):
"""Extra paramerters for Longformer configs.
Attributes:
attention_window: list of ints representing the window size for each layer.
global_attention_size: the size of global attention used for each token.
pad_token_id: the token id for the pad token
"""
attention_window: List[int] = dataclasses.field(default_factory=list)
global_attention_size: int = 0
pad_token_id: int = 1
@base_config.bind(LongformerEncoderConfig)
def get_encoder(encoder_cfg: LongformerEncoderConfig):
"""Gets a 'LongformerEncoder' object.
Args:
encoder_cfg: A 'LongformerEncoderConfig'.
Returns:
A encoder object.
"""
encoder = LongformerEncoder(
attention_window=encoder_cfg.attention_window,
global_attention_size=encoder_cfg.global_attention_size,
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.intermediate_size,
inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
output_dropout=encoder_cfg.dropout_rate,
attention_dropout=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_size,
norm_first=encoder_cfg.norm_first)
return encoder
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention block. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
import math
import string
import numpy as np
import tensorflow as tf
from official.modeling.tf_utils import get_shape_list
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
`(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
`bs` and `<non-attention dims>` are treated as `<batch dims>`.
The attention operations can be generalized:
(1) Query-key dot product:
`(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
<key attention dims>, num_heads, channels) -> (<batch dims>,
num_heads, <query attention dims>, <key attention dims>)`
(2) Combination:
`(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
<query attention dims>, num_heads, channels)`
Args:
rank: Rank of query, key, value tensors.
attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be applied
to.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = f"{source_notation},{target_notation}->{product_notation}"
attn_scores_rank = len(product_notation)
combine_equation = f"{product_notation},{source_notation}->{target_notation}"
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = f"{input_str},{kernel_str}->{output_str}"
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerAttention(tf.keras.layers.MultiHeadAttention):
"""LongformerAttention.
Args:
attention_window: int representing the window size for attention.
layer_id: int of the id of the layer.
global_attention_size: the size of global attention used for each token.
"""
def __init__(self, attention_window, layer_id, global_attention_size,
**kwargs):
super().__init__(**kwargs)
self._layer_id = layer_id
self._attention_window = attention_window
assert (self._attention_window % 2 == 0), (
f"`attention_window` for layer {self._layer_id} has to be an even "
f"value. Given {self.attention_window}")
assert (self._attention_window > 0), (
f"`attention_window` for layer {self._layer_id} has to be positive. "
f"Given {self.attention_window}")
self._one_sided_attn_window_size = self._attention_window // 2
self.global_attention_size = global_attention_size
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: Query tensor or TensorShape.
value: Value tensor or TensorShape.
key: Key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
self._query_shape = tf.TensorShape(query.shape)
else:
self._query_shape = tf.TensorShape(query)
if hasattr(value, "shape"):
self._value_shape = tf.TensorShape(value.shape)
else:
self._value_shape = tf.TensorShape(value)
if key is None:
self._key_shape = self._value_shape
elif hasattr(key, "shape"):
self._key_shape = tf.TensorShape(key.shape)
else:
self._key_shape = tf.TensorShape(key)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# Any setup work performed only once should happen in an `init_scope`
# to avoid creating symbolic Tensors that will later pollute any eager
# operations.
# with tf_utils.maybe_init_scope(self):
# TODO(crickwu): check whether tf_utils.maybe_init_scope(self) (keras)
# is needed.
free_dims = self._query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
self._global_query_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="global_query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
self._global_key_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="global_key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
self._global_value_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="global_value",
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# it support mult-head einsum computations.
self._build_attention(output_rank)
self._global_dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
# self._output_dense = self._make_output_dense(
# free_dims, common_kwargs, "attention_output")
self._output_dense = tf.keras.layers.Dense(
units=self._num_heads * self._key_dim, name="dense", **common_kwargs)
def call(self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
training=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
multi-head Q, K, V inputs. Users can override this function for customized
attention implementation.
Args:
hidden_states: inputs for generating query, key and value tensors.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
is_index_masked: boolean indicating whether the index is masked.
is_index_global_attn: boolean indicating whether the index is global
attention.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
Returns:
attention_output: Multi-headed outputs of attention computation.
"""
if not self._built_from_signature:
self._build_from_signature(
query=hidden_states, value=hidden_states, key=hidden_states)
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(hidden_states)
# `key` = [B, S, N, H]
key = self._key_dense(hidden_states)
# `value` = [B, S, N, H]
value = self._value_dense(hidden_states)
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query, key, self._one_sided_attn_window_size)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
tf.ones(get_shape_list(attention_mask)),
attention_mask,
self._one_sided_attn_window_size,
)
# pad local attention probs
attn_scores += diagonal_mask
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(attn_scores),
[
batch_size, seq_len, self._num_heads,
self._one_sided_attn_window_size * 2 + 1
],
message=f"attn_probs should be of size "
f"({batch_size}, {seq_len}, {num_heads}, "
f"{self._one_sided_attn_window_size * 2 + 1}),"
f" but is of size {get_shape_list(attn_scores)}",
)
# compute global attn indices required through out forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn,
self.global_attention_size)
# this function is only relevant for global attention
if self.global_attention_size > 0:
attn_scores = self._concat_with_global_key_attn_probs(
attn_scores=attn_scores,
query_vectors=query,
key_vectors=key,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
else:
pass
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked,
# replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads,
# self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads,
# self.one_sided_attn_window_size * 2 + 1]
if self.global_attention_size > 0:
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
max_num_global_attn_indices + 1),
)
else:
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
)
attn_probs = tf.where(
masked_index,
tf.zeros(get_shape_list(masked_index), dtype=attn_probs.dtype),
attn_probs,
)
layer_head_mask = None
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size "
f"{(self._num_heads)}, but is "
f"{get_shape_list(layer_head_mask)}",
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
attn_probs = self._dropout_layer(attn_probs, training=training)
value_vectors = tf.reshape(
value, (batch_size, seq_len, self._num_heads, self._key_dim))
# if global attention, compute sum of global and local attn
if self.global_attention_size > 0:
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
attn_output = self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self._one_sided_attn_window_size)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(attn_output),
[batch_size, seq_len, self._num_heads, head_dim],
message="Unexpected size",
)
attn_output = tf.reshape(
attn_output,
(batch_size, seq_len, self._num_heads * self._key_dim)) # FIXME
# compute value for global attention and overwrite to attention output
# TODO(crickwu): remove the redundant computation
if self.global_attention_size > 0:
attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( # pylint: disable=unused-variable
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
training=training,
)
else:
global_attn_probs = tf.zeros(
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len))
# make sure that local attention probabilities are set to 0 for indices of
# global attn
if self.global_attention_size > 0:
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
max_num_global_attn_indices + 1),
)
else:
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
)
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(
get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
attn_probs,
)
# we can return extra information here
# (attn_output, attn_probs, global_attn_probs)
return attn_output
def get_config(self):
config = {
"layer_id": self._layer_id,
"attention_window": self._one_sided_attn_window_size,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
"""Matrix multiplication of query and key tensors.
This multiplication uses a sliding window attention pattern.
This implementation splits the input into overlapping chunks of size
2w (e.g. 512 for pretrained Longformer) with an overlap of size
window_overlap.
Args:
query: query tensor.
key: key tensor.
window_overlap: int.
Returns:
diagonal_attention_scores: tensor.
"""
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. "
f"Given {seq_len}",
)
tf.debugging.assert_equal(
get_shape_list(query),
get_shape_list(key),
message=f"Shape of query and key should be equal, but got query: "
f"{get_shape_list(query)} and key: {get_shape_list(key)}",
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one,
# then chunk seq_len into chunks of size window_overlap * 2
query = tf.reshape(
tf.transpose(query, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
key = tf.reshape(
tf.transpose(key, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim))
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query,
chunked_key) # multiply
# convert diagonals into columns
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are
# combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first (window_overlap) columns
# are the window_overlap lower triangles (attention from a word to
# window_overlap previous words). The following column is attention score
# from each word to itself, then
# followed by window_overlap columns for the upper triangle.
# copy parts from diagonal_chunked_attention_scores into the combined matrix
# of attentions - copying the main diagonal and the upper triangle
# TODO(crickwu): This code is most likely not very efficient and should be
# improved.
diagonal_attn_scores_up_triang = tf.concat(
[
diagonal_chunked_attention_scores[:, :, :window_overlap, :
window_overlap + 1],
diagonal_chunked_attention_scores[:, -1:,
window_overlap:, :window_overlap +
1],
],
axis=1,
)
# - copying the lower triangle
diagonal_attn_scores_low_triang = tf.concat(
[
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
diagonal_chunked_attention_scores[:, :, -(window_overlap + 1):-1,
window_overlap + 1:],
],
axis=1,
)
diagonal_attn_scores_first_chunk = tf.concat(
[
tf.roll(
diagonal_chunked_attention_scores,
shift=[1, window_overlap],
axis=[2, 3],
)[:, :, :window_overlap, :window_overlap],
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
],
axis=1,
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
) < 1)
diagonal_attn_scores_low_triang = tf.where(
first_chunk_mask,
diagonal_attn_scores_first_chunk,
diagonal_attn_scores_low_triang,
)
# merging upper and lower triangle
diagonal_attention_scores = tf.concat(
[diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang],
axis=-1)
# separate batch_size and num_heads dimensions again
diagonal_attention_scores = tf.transpose(
tf.reshape(
diagonal_attention_scores,
(batch_size, num_heads, seq_len, 2 * window_overlap + 1),
),
(0, 2, 1, 3),
)
diagonal_attention_scores = self._mask_invalid_locations(
diagonal_attention_scores, window_overlap)
return diagonal_attention_scores
@staticmethod
def _mask_invalid_locations(input_tensor, window_overlap):
# create correct upper triangle bool mask
mask_2d_upper = tf.reverse(
tf.linalg.band_part(
tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
axis=[0],
)
# pad to full matrix
padding = tf.convert_to_tensor(
[[0, get_shape_list(input_tensor)[1] - window_overlap],
[0, get_shape_list(input_tensor)[3] - window_overlap - 1]])
# create lower mask
mask_2d = tf.pad(mask_2d_upper, padding)
# combine with upper mask
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
# broadcast to full matrix
mask_4d = tf.tile(mask_2d[None, :, None, :],
(get_shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor)
# mask
input_tensor = tf.where(
tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
return input_tensor
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value,
window_overlap):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value."""
batch_size, seq_len, num_heads, head_dim = get_shape_list(value)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
get_shape_list(attn_probs)[:3],
get_shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
get_shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len
# into chunks of size 2 window overlap
chunked_attn_probs = tf.reshape(
tf.transpose(attn_probs, (0, 2, 1, 3)),
(
batch_size * num_heads,
seq_len // window_overlap,
window_overlap,
2 * window_overlap + 1,
),
)
# group batch_size and num_heads dimensions into one
value = tf.reshape(
tf.transpose(value, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
# pad seq_len with w at the beginning of the sequence and another window
# overlap at the end
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap],
[0, 0]])
padded_value = tf.pad(value, paddings, constant_values=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of
# size window overlap
frame_size = 3 * window_overlap * head_dim
frame_hop_size = (get_shape_list(padded_value)[1] * head_dim -
frame_size) // chunks_count
chunked_value = tf.signal.frame(
tf.reshape(padded_value, (batch_size * num_heads, -1)),
frame_size,
frame_hop_size,
)
chunked_value = tf.reshape(
chunked_value,
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
head_dim),
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(chunked_value),
[
batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
head_dim
],
message="Chunked value has the wrong shape",
)
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
context = tf.transpose(
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
(0, 2, 1, 3),
)
return context
@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
"""Pads rows and then flips rows and columns."""
hidden_states_padded = tf.pad(
hidden_states_padded, paddings
) # padding value is not important because it will be overwritten
batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
hidden_states_padded)
hidden_states_padded = tf.reshape(
hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
return hidden_states_padded
@staticmethod
def _pad_and_diagonalize(chunked_hidden_states):
"""Shifts every row 1 step right, converting columns into diagonals.
Example::
chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
2.0514, -1.1600, 0.5372, 0.2629 ]
window_overlap = num_rows = 4
(pad & diagonalize) =>
[ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
Args:
chunked_hidden_states: tensor.
Returns:
padded_hidden_stategs: tensor.
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(
chunked_hidden_states)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0],
[0, window_overlap + 1]])
chunked_hidden_states = tf.pad(chunked_hidden_states, paddings)
chunked_hidden_states = tf.reshape(chunked_hidden_states,
(total_num_heads, num_chunks, -1))
chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap]
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
(total_num_heads, num_chunks, window_overlap,
window_overlap + hidden_dim),
)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
@staticmethod
def _chunk(hidden_states, window_overlap):
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w."""
batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
# define frame size and frame stride (similar to convolution)
frame_hop_size = window_overlap * hidden_dim
frame_size = 2 * frame_hop_size
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length * hidden_dim))
# chunk with overlap
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
frame_hop_size)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=f"Make sure chunking is correctly applied. `Chunked hidden "
f"states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got "
f"{get_shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
)
return chunked_hidden_states
@staticmethod
def _get_global_attn_indices(is_index_global_attn, global_attention_size):
"""Computes global attn indices required throughout forward pass."""
# All global attention size are fixed through global_attention_size
batch_size, _ = get_shape_list(is_index_global_attn)
max_num_global_attn_indices = global_attention_size
row_indices = tf.range(batch_size)
row_indices = tf.repeat(
tf.expand_dims(row_indices, axis=0),
repeats=[global_attention_size],
axis=0)
row_indices = tf.reshape(row_indices,
(batch_size * global_attention_size, 1))
col_indices = tf.range(global_attention_size)
col_indices = tf.repeat(
tf.expand_dims(col_indices, axis=1), repeats=[batch_size], axis=0)
is_index_global_attn_nonzero = tf.concat((row_indices, col_indices), axis=1)
# this is actually same as `is_index_global_attn_nonzero`,
# since we assume all global attention are the same size
is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
axis=1)
# empty tensor
is_local_index_no_global_attn_nonzero = tf.reshape(
tf.expand_dims(tf.range(0), axis=1), (0, 2))
return (
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
)
def _concat_with_global_key_attn_probs(
self,
attn_scores,
key_vectors,
query_vectors,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
):
batch_size = get_shape_list(key_vectors)[0]
# select global key vectors
global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
# create only global key vectors
key_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_key_vectors,
shape=(
batch_size,
max_num_global_attn_indices,
self._num_heads,
self._key_dim,
),
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
key_vectors_only_global)
# (batch_size, max_num_global_attn_indices, seq_len, num_heads)
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key,
(0, 3, 1, 2))
mask_shape = (
get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
get_shape_list(attn_probs_from_global_key_trans)[-2:])
mask = tf.ones(mask_shape) * -10000.0
mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
# scatter mask
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
attn_probs_from_global_key_trans,
is_local_index_no_global_attn_nonzero,
mask,
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans,
(0, 2, 3, 1))
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)
return attn_scores
def _compute_attn_output_with_global_indices(
self,
value_vectors,
attn_probs,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
):
batch_size = get_shape_list(attn_probs)[0]
# cut local attn probs to global only
attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
# select global value vectors
global_value_vectors = tf.gather_nd(value_vectors,
is_index_global_attn_nonzero)
# create only global value vectors
value_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_value_vectors,
shape=(
batch_size,
max_num_global_attn_indices,
self._num_heads,
self._key_dim,
),
)
# compute attn output only global
attn_output_only_global = tf.einsum("blhs,bshd->blhd",
attn_probs_only_global,
value_vectors_only_global)
# reshape attn probs
attn_probs_without_global = attn_probs[:, :, :,
max_num_global_attn_indices:]
# compute attn output with global
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
attn_probs_without_global, value_vectors,
self._one_sided_attn_window_size)
return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
self,
attn_output,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
is_index_masked,
training,
):
batch_size, seq_len = get_shape_list(hidden_states)[:2]
# prepare global hidden states
global_attn_hidden_states = tf.gather_nd(hidden_states,
is_index_global_attn_nonzero)
global_attn_hidden_states = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_attn_hidden_states,
shape=(batch_size, max_num_global_attn_indices,
self._num_heads * self._key_dim),
)
# global key, query, value
global_query_vectors_only_global = self._global_query_dense(
global_attn_hidden_states)
global_key_vectors = self._global_key_dense(hidden_states)
global_value_vectors = self._global_value_dense(hidden_states)
# normalize
global_query_vectors_only_global /= tf.math.sqrt(
tf.cast(self._key_dim, dtype=global_query_vectors_only_global.dtype))
global_query_vectors_only_global = self.reshape_and_transpose(
global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors,
batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors,
batch_size)
# compute attn scores
global_attn_scores = tf.matmul(
global_query_vectors_only_global, global_key_vectors, transpose_b=True)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(global_attn_scores),
[batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
message=f"global_attn_scores have the wrong size. Size should be"
f"{(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, "
f"but is {get_shape_list(global_attn_scores)}.",
)
global_attn_scores = tf.reshape(
global_attn_scores,
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
)
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
) + tuple(get_shape_list(global_attn_scores_trans)[-2:])
global_attn_mask = tf.ones(mask_shape) * -10000.0
global_attn_mask = tf.cast(
global_attn_mask, dtype=global_attn_scores_trans.dtype)
# scatter mask
global_attn_scores_trans = tf.tensor_scatter_nd_update(
global_attn_scores_trans,
is_local_index_no_global_attn_nonzero,
global_attn_mask,
)
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
# mask global attn scores
attn_mask = tf.tile(is_index_masked[:, None, None, :],
(1, get_shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
(batch_size * self._num_heads, max_num_global_attn_indices, seq_len),
)
# compute global attn probs
global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
# apply layer head masking
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size "
f"{(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
)
global_attn_probs_float = tf.reshape(
layer_head_mask,
(1, -1, 1, 1)) * tf.reshape(global_attn_probs_float,
(batch_size, self._num_heads,
max_num_global_attn_indices, seq_len))
global_attn_probs_float = tf.reshape(
global_attn_probs_float,
(batch_size * self._num_heads, max_num_global_attn_indices, seq_len))
# dropout
global_attn_probs = self._global_dropout_layer(
global_attn_probs_float, training=training)
# global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(global_attn_output),
[
batch_size * self._num_heads, max_num_global_attn_indices,
self._key_dim
],
message=f"global_attn_output tensor has the wrong size. Size should be "
f"{(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, "
f"but is {get_shape_list(global_attn_output)}.",
)
global_attn_output = tf.reshape(
global_attn_output,
(batch_size, self._num_heads, max_num_global_attn_indices,
self._key_dim),
)
# get only non zero global attn output
nonzero_global_attn_output = tf.gather_nd(
tf.transpose(global_attn_output, (0, 2, 1, 3)),
is_local_index_global_attn_nonzero,
)
nonzero_global_attn_output = tf.reshape(
nonzero_global_attn_output,
(get_shape_list(is_local_index_global_attn_nonzero)[0], -1),
)
# overwrite values with global attention
attn_output = tf.tensor_scatter_nd_update(attn_output,
is_index_global_attn_nonzero,
nonzero_global_attn_output)
global_attn_probs = tf.reshape(
global_attn_probs,
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len))
attn_output = self._output_dense(attn_output)
return attn_output, global_attn_probs
def reshape_and_transpose(self, vector, batch_size):
return tf.reshape(
tf.transpose(
tf.reshape(vector,
(batch_size, -1, self._num_heads, self._key_dim)),
(0, 2, 1, 3),
),
(batch_size * self._num_heads, -1, self._key_dim),
)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""
import numpy as np
import tensorflow as tf
from official.modeling.tf_utils import get_shape_list
from official.projects.longformer import longformer_attention
def _create_mock_attention_data(num_heads,
key_dim,
value_dim,
q_seq_length,
kv_seq_length,
batch_size,
include_mask=False):
"""Creates mock testing data.
Args:
num_heads: `int`, Number of attention heads.
key_dim: `int`, Size of query head.
value_dim: `int`, Size of key, value dim.
q_seq_length: `int`, query sequence length of the input.
kv_seq_length: `int`, key, value sequence length of the input.
batch_size: `int`, the batch size.
include_mask: optional `bool`, whether or not to include mask data.
Returns:
A dictionary with `str` as keys and `Tensor` as values.
"""
query_shape = (batch_size, q_seq_length, key_dim)
value_shape = (batch_size, kv_seq_length, value_dim)
data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
total_seq_length = kv_seq_length
if include_mask:
mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
mask_data = np.random.randint(2, size=mask_shape).astype('float32')
mask_data = dict(attention_mask=mask_data)
data.update(mask_data)
return data
class LongformerAttentionTest(tf.test.TestCase):
def setUp(self):
super(LongformerAttentionTest, self).setUp()
np.random.seed(0)
tf.random.set_seed(0)
def _get_hidden_states(self):
return tf.convert_to_tensor(
[[
[
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]],
dtype=tf.float32,
)
def test_diagonalize(self):
hidden_states = self._get_hidden_states()
hidden_states = tf.reshape(hidden_states,
(1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
window_overlap_size = get_shape_list(chunked_hidden_states)[2]
self.assertEqual(window_overlap_size, 4)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
chunked_hidden_states)
self.assertEqual(
get_shape_list(padded_hidden_states)[-1],
get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf.debugging.assert_near(
padded_hidden_states[0, 0, 0, :4],
chunked_hidden_states[0, 0, 0],
rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, 0, 4:],
tf.zeros((3,), dtype=tf.dtypes.float32),
rtol=1e-3)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, 3:],
chunked_hidden_states[0, 0, -1],
rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, :3],
tf.zeros((3,), dtype=tf.dtypes.float32),
rtol=1e-3)
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
dtype=tf.dtypes.int32)
hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
hidden_states, paddings)
self.assertEqual(get_shape_list(padded_hidden_states), [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(
expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, 0, -1, :],
tf.reshape(padded_hidden_states, (1, -1))[0, 24:32],
rtol=1e-6)
def test_mask_invalid_locations(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length, hidden_size))
hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states, 1)
hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states, 2)
hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, :, :3], 2)
hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, 2:, :], 2)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)), 8)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)), 24)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)), 24)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)), 12)
def test_chunk(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length, hidden_size))
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = tf.convert_to_tensor(
[0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor(
[0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
self.assertEqual(get_shape_list(chunked_hidden_states), [1, 3, 4, 4])
tf.debugging.assert_near(
chunked_hidden_states[0, :, 0, 0],
expected_slice_along_seq_length,
rtol=1e-3)
tf.debugging.assert_near(
chunked_hidden_states[0, 0, :, 0],
expected_slice_along_chunk,
rtol=1e-3)
def test_layer_local_attn(self):
hidden_states = self._get_hidden_states()
batch_size, seq_length, _ = hidden_states.shape
layer = longformer_attention.LongformerAttention(
num_heads=2,
key_dim=4,
value_dim=4,
layer_id=0,
attention_window=4,
global_attention_size=0,
)
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
is_index_global_attn = tf.math.greater(attention_mask, 1)
attention_mask = tf.where(
tf.range(4)[None, :, None, None] > 1, -10000.0,
attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)[0]
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
def test_layer_global_attn(self):
layer = longformer_attention.LongformerAttention(
num_heads=2,
key_dim=4,
value_dim=4,
layer_id=0,
attention_window=4,
global_attention_size=1,
)
hidden_states = self._get_hidden_states()
hidden_states = tf.concat(
[self._get_hidden_states(),
self._get_hidden_states() - 0.5], axis=0)
_, seq_length, _ = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(
tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(
tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(
tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
hidden_states=hidden_states,
attention_mask=-tf.math.abs(attention_mask),
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer encoder. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, List, Optional, Union
from absl import logging
import tensorflow as tf
from official.modeling.tf_utils import get_shape_list
from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock
_Initializer = Union[str, tf.keras.initializers.Initializer]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
class LongformerEncoder(tf.keras.layers.Layer):
"""LongformerEncoder.
Args:
vocab_size: The size of the token vocabulary.
attention_window: list of ints representing the window size for each layer.
global_attention_size: the size of global attention used for each token.
pad_token_id: the token id for the pad token
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
"""
def __init__(
self,
vocab_size: int,
attention_window: Union[List[int], int] = 512,
global_attention_size: int = 0,
pad_token_id: int = 1,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
super().__init__(**kwargs)
# Longformer args
self._attention_window = attention_window
self._global_attention_size = global_attention_size
self._pad_token_id = pad_token_id
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = LongformerEncoderBlock(
global_attention_size=global_attention_size,
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
attention_window=attention_window[i],
layer_id=i,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name=f'transformer/layer_{i}')
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'attention_window': attention_window,
'global_attention_size': global_attention_size,
'pad_token_id': pad_token_id,
}
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
def call(self, inputs):
word_embeddings = None
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids') # input_ids
mask = inputs.get('input_mask') # attention_mask
type_ids = inputs.get('input_type_ids') # token_type_ids
word_embeddings = inputs.get('input_word_embeddings',
None) # input_embeds
else:
raise ValueError(f'Unexpected inputs type to {self.__class__}.')
(
padding_len,
word_ids,
mask,
type_ids,
word_embeddings,
) = self._pad_to_window_size(
word_ids=word_ids,
mask=mask,
type_ids=type_ids,
word_embeddings=word_embeddings,
pad_token_id=self._pad_token_id)
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
# absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = word_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
batch_size, seq_len = get_shape_list(mask)
# create masks with fixed len global_attention_size
mask = tf.transpose(
tf.concat(
values=[
tf.ones(
(self._global_attention_size, batch_size), tf.int32) * 2,
tf.transpose(mask)[self._global_attention_size:]
],
axis=0))
is_index_masked = tf.math.less(mask, 1)
is_index_global_attn = tf.transpose(
tf.concat(
values=[
tf.ones((self._global_attention_size, batch_size), tf.bool),
tf.zeros((seq_len - self._global_attention_size, batch_size),
tf.bool)
],
axis=0))
# Longformer
attention_mask = mask
extended_attention_mask = tf.reshape(
attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1))
attention_mask = tf.cast(
tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
encoder_outputs = []
x = embeddings
# TFLongformerEncoder
for layer in self._transformer_layers:
x = layer([x, attention_mask, is_index_masked, is_index_global_attn])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
if padding_len > 0:
last_encoder_output = last_encoder_output[:, :-padding_len]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=last_encoder_output,
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you contine to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
def _pad_to_window_size(
self,
word_ids,
mask,
type_ids,
word_embeddings,
pad_token_id,
):
# padding
attention_window = max(self._attention_window)
assert (attention_window %
2 == 0), ('`attention_window` should be an even value.'
f'Given {attention_window}')
input_shape = get_shape_list(
word_ids) if word_ids is not None else get_shape_list(word_embeddings)
batch_size, seq_len = input_shape[:2]
if seq_len is not None:
padding_len = (attention_window -
seq_len % attention_window) % attention_window
else:
padding_len = 0
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
if word_ids is not None:
word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)
if word_embeddings is not None:
def pad_embeddings():
word_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
word_embeddings_padding = self._embedding_layer(word_ids_padding)
return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)
word_embeddings = tf.cond(
tf.math.greater(padding_len, 0), pad_embeddings,
lambda: word_embeddings)
mask = tf.pad(
mask, paddings,
constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(
type_ids, paddings, constant_values=0) # pad with token_type_id = 0
return (
padding_len,
word_ids,
mask,
token_type_ids,
word_embeddings,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment