ModelZoo / ResNet50_tensorflow, commit 8b641b13 (unverified)
Authored Mar 26, 2022 by Srihari Humbarwadi; committed via GitHub on Mar 26, 2022.

Merge branch 'tensorflow:master' into panoptic-deeplab

Parents: 7cffacfe, 357fa547
Changes: 411 files changed in this merge. The first 20 changed files are shown below, accounting for 2,870 additions and 15 deletions (+2870 / -15).
- official/projects/longformer/experiments/glue_mnli.yaml (+47, -0)
- official/projects/longformer/experiments/glue_mnli_allenai.yaml (+48, -0)
- official/projects/longformer/experiments/pretraining_512.yaml (+74, -0)
- official/projects/longformer/longformer.py (+69, -0)
- official/projects/longformer/longformer_attention.py (+1082, -0)
- official/projects/longformer/longformer_attention_test.py (+306, -0)
- official/projects/longformer/longformer_encoder.py (+365, -0)
- official/projects/longformer/longformer_encoder_block.py (+340, -0)
- official/projects/longformer/longformer_encoder_test.py (+97, -0)
- official/projects/longformer/longformer_experiments.py (+123, -0)
- official/projects/longformer/train.py (+6, -7)
- official/projects/longformer/utils/convert_pretrained_pytorch_checkpoint_to_tf.py (+200, -0)
- official/projects/longformer/utils/longformer_tokenizer_to_tfrecord.py (+112, -0)
- official/projects/movinet/README.md (+1, -2)
- official/projects/movinet/modeling/movinet.py (+0, -1)
- official/projects/movinet/modeling/movinet_layers.py (+0, -1)
- official/projects/movinet/modeling/movinet_layers_test.py (+0, -1)
- official/projects/movinet/modeling/movinet_model_test.py (+0, -1)
- official/projects/movinet/modeling/movinet_test.py (+0, -1)
- official/projects/movinet/tools/export_saved_model.py (+0, -1)

official/projects/longformer/experiments/glue_mnli.yaml (new file, mode 100644)
```yaml
task:
  hub_module_url: ''
  model:
    num_classes: 3
    encoder:
      type: any
      any:
        max_position_embeddings: 512
        attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
        global_attention_size: 1
  metric_type: 'accuracy'
  train_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: true
    seq_length: 128
  validation_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: false
    seq_length: 128
trainer:
  checkpoint_interval: 1000
  continuous_eval_timeout: 7200
  optimizer_config:
    learning_rate:
      polynomial:
        decay_steps: 61359
        end_learning_rate: 0.0
        initial_learning_rate: 3.0e-05
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 6136
      type: polynomial
  steps_per_loop: 100
  summary_interval: 100
  # Training data size 392,702 examples, 5 epochs.
  train_steps: 61359
  validation_interval: 2000
  validation_steps: 307
```
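The trainer numbers are internally consistent: the comment above `train_steps` pins them down. A quick check in plain Python, using only values that appear in this file:

```python
# 392,702 MNLI training examples for 5 epochs at global batch size 32.
examples, epochs, batch = 392_702, 5, 32
train_steps = examples * epochs // batch
print(train_steps)               # 61359, matching trainer.train_steps
print(round(train_steps * 0.1))  # 6136, matching warmup_steps (10% warmup)
```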

official/projects/longformer/experiments/glue_mnli_allenai.yaml (new file, mode 100644)
```yaml
task:
  hub_module_url: ''
  model:
    num_classes: 3
    encoder:
      type: any
      any:
        max_position_embeddings: 4098
        attention_window: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
        global_attention_size: 1
        vocab_size: 50265
  metric_type: 'accuracy'
  train_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: true
    seq_length: 512
  validation_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: false
    seq_length: 512
trainer:
  checkpoint_interval: 1000
  continuous_eval_timeout: 7200
  optimizer_config:
    learning_rate:
      polynomial:
        decay_steps: 61359
        end_learning_rate: 0.0
        initial_learning_rate: 3.0e-05
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 6136
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  # Training data size 392,702 examples, 5 epochs.
  train_steps: 61359
  validation_interval: 2000
  validation_steps: 307
```
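Relative to `glue_mnli.yaml`, this variant matches the AllenAI Longformer checkpoints: RoBERTa's 50,265-token vocabulary, 4,098 position embeddings (RoBERTa-style embeddings reserve two extra position slots beyond the 4,096-token maximum), a 128-token window in every layer, and 512-token inputs. Sliding-window attention generally requires an even window and a padded sequence length that is a multiple of the window, both of which hold here; a minimal check, assuming only the values in this file:

```python
# Sanity-check the window/sequence-length relationship used above.
attention_window = [128] * 12
seq_length = 512
assert all(w > 0 and w % 2 == 0 for w in attention_window)  # even windows
assert all(seq_length % w == 0 for w in attention_window)   # 512 = 4 * 128
```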

official/projects/longformer/experiments/pretraining_512.yaml (new file, mode 100644)
```yaml
task:
  init_checkpoint: ""
  model:
    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1,
                 inner_dim: 768, name: next_sentence, num_classes: 2}]
    encoder:
      type: any
      any:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        embedding_size: 768
        hidden_activation: gelu
        hidden_size: 768
        initializer_range: 0.02
        intermediate_size: 3072
        max_position_embeddings: 512
        num_attention_heads: 12
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 30522
        attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
        global_attention_size: 1
  train_data:
    drop_remainder: true
    global_batch_size: 256
    # input_path lists 500 comma-separated Wikipedia pretraining shards,
    # wikipedia.tfrecord-00000-of-00500 through wikipedia.tfrecord-00499-of-00500:
    input_path: gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00000-of-00500,...,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00499-of-00500
    is_training: true
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
  validation_data:
    drop_remainder: true
    global_batch_size: 256
    input_path: TODO
    is_training: false
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
trainer:
  checkpoint_interval: 20000
  max_to_keep: 5
  optimizer_config:
    learning_rate:
      polynomial:
        cycle: false
        decay_steps: 1000000
        end_learning_rate: 0.0
        initial_learning_rate: 0.0001
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 10000
      type: polynomial
  steps_per_loop: 50
  summary_interval: 50
  train_steps: 1000000
  validation_interval: 1000
  validation_steps: 64
```
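The validation split here, like both GLUE configs, leaves `input_path: TODO`. One way to fill those fields without editing the checked-in file is to load the YAML, patch it, and write a private copy. The sketch below uses only PyYAML; the output and data paths are hypothetical placeholders:

```python
import yaml

# Load the checked-in config, fill the TODO field, and write a local copy.
with open("official/projects/longformer/experiments/pretraining_512.yaml") as f:
    cfg = yaml.safe_load(f)

# Hypothetical eval shards; substitute your own TFRecord files.
cfg["task"]["validation_data"]["input_path"] = "/tmp/wiki_eval.tfrecord"

with open("/tmp/pretraining_512_local.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
```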

official/projects/longformer/longformer.py (new file, mode 100644)
```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Longformer model configurations and instantiation methods."""
import dataclasses
from typing import List

import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.projects.longformer.longformer_encoder import LongformerEncoder


@dataclasses.dataclass
class LongformerEncoderConfig(encoders.BertEncoderConfig):
  """Extra parameters for Longformer configs.

  Attributes:
    attention_window: List of ints representing the window size for each layer.
    global_attention_size: The size of global attention used for each token.
    pad_token_id: The token id for the pad token.
  """
  attention_window: List[int] = dataclasses.field(default_factory=list)
  global_attention_size: int = 0
  pad_token_id: int = 1


@base_config.bind(LongformerEncoderConfig)
def get_encoder(encoder_cfg: LongformerEncoderConfig):
  """Gets a `LongformerEncoder` object.

  Args:
    encoder_cfg: A `LongformerEncoderConfig`.

  Returns:
    An encoder object.
  """
  encoder = LongformerEncoder(
      attention_window=encoder_cfg.attention_window,
      global_attention_size=encoder_cfg.global_attention_size,
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      inner_dim=encoder_cfg.intermediate_size,
      inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      output_dropout=encoder_cfg.dropout_rate,
      attention_dropout=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      norm_first=encoder_cfg.norm_first)
  return encoder
```
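For reference, a minimal sketch of exercising this factory directly, assuming the Model Garden packages are importable and that `base_config.bind` returns the decorated function unchanged (so `get_encoder` stays directly callable); the window list mirrors the 12-layer YAML configs above:

```python
# Hypothetical direct use of the factory defined above.
config = LongformerEncoderConfig(
    attention_window=[32] * 12,  # one window per layer, as in pretraining_512
    global_attention_size=1,
    vocab_size=30522)
encoder = get_encoder(config)
```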

official/projects/longformer/longformer_attention.py (new file, mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention block. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
import math
import string

import numpy as np
import tensorflow as tf

from official.modeling.tf_utils import get_shape_list

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
  `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
  `bs` and `<non-attention dims>` are treated as `<batch dims>`.

  The attention operations can be generalized:
  (1) Query-key dot product:
  `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
  <key attention dims>, num_heads, channels) -> (<batch dims>, num_heads,
  <query attention dims>, <key attention dims>)`
  (2) Combination:
  `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch
  dims>, <query attention dims>, num_heads, channels)`

  Args:
    rank: Rank of query, key, value tensors.
    attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be
      applied to.

  Returns:
    Einsum equations.
  """
  target_notation = _CHR_IDX[:rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  letter_offset = rank
  source_notation = ""
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = f"{source_notation},{target_notation}->{product_notation}"
  attn_scores_rank = len(product_notation)
  combine_equation = f"{product_notation},{source_notation}->{target_notation}"
  return dot_product_equation, combine_equation, attn_scores_rank
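As a concrete check of the notation, tracing the loop by hand for the usual rank-4 projected inputs `(batch, seq, num_heads, channels)` with attention over the sequence axis gives:

dot_eq, combine_eq, scores_rank = _build_attention_equation(4, attn_axes=(1,))
# dot_eq == "aecd,abcd->acbe"      (key, query) -> attention scores
# combine_eq == "acbe,aecd->abcd"  (scores, value) -> attention output
# scores_rank == 4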
def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention."""
  input_str = ""
  kernel_str = ""
  output_str = ""
  bias_axes = ""
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    output_str += char

  letter_offset += free_dims
  for i in range(bound_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    kernel_str += char

  letter_offset += bound_dims
  for i in range(output_dims):
    char = _CHR_IDX[i + letter_offset]
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = f"{input_str},{kernel_str}->{output_str}"

  return equation, bias_axes, len(output_str)


def _get_output_shape(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
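A similar hand-traced example of the projection helpers (derived from the loops above, not captured output):

equation, bias_axes, output_rank = _build_proj_equation(
    free_dims=2, bound_dims=1, output_dims=2)
# equation == "abc,cde->abde", bias_axes == "de", output_rank == 4
_get_output_shape(3, [12, 64])  # == [None, 12, 64]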
@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerAttention(tf.keras.layers.MultiHeadAttention):
  """LongformerAttention.

  Args:
    attention_window: int representing the window size for attention.
    layer_id: int of the id of the layer.
    global_attention_size: the size of global attention used for each token.
  """

  def __init__(self, attention_window, layer_id, global_attention_size,
               **kwargs):
    super().__init__(**kwargs)
    self._layer_id = layer_id
    self._attention_window = attention_window
    assert (self._attention_window % 2 == 0), (
        f"`attention_window` for layer {self._layer_id} has to be an even "
        f"value. Given {self._attention_window}")
    assert (self._attention_window > 0), (
        f"`attention_window` for layer {self._layer_id} has to be positive. "
        f"Given {self._attention_window}")
    self._one_sided_attn_window_size = self._attention_window // 2
    self.global_attention_size = global_attention_size
  def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: Query tensor or TensorShape.
      value: Value tensor or TensorShape.
      key: Key tensor or TensorShape.
    """
    self._built_from_signature = True
    if hasattr(query, "shape"):
      self._query_shape = tf.TensorShape(query.shape)
    else:
      self._query_shape = tf.TensorShape(query)
    if hasattr(value, "shape"):
      self._value_shape = tf.TensorShape(value.shape)
    else:
      self._value_shape = tf.TensorShape(value)
    if key is None:
      self._key_shape = self._value_shape
    elif hasattr(key, "shape"):
      self._key_shape = tf.TensorShape(key.shape)
    else:
      self._key_shape = tf.TensorShape(key)
    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Any setup work performed only once should happen in an `init_scope`
    # to avoid creating symbolic Tensors that will later pollute any eager
    # operations.
    # with tf_utils.maybe_init_scope(self):
    # TODO(crickwu): check whether tf_utils.maybe_init_scope(self) (keras)
    # is needed.
    free_dims = self._query_shape.rank - 1
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=1, output_dims=2)
    self._query_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="query",
        **common_kwargs)
    self._global_query_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="global_query",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        self._key_shape.rank - 1, bound_dims=1, output_dims=2)
    self._key_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="key",
        **common_kwargs)
    self._global_key_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="global_key",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        self._value_shape.rank - 1, bound_dims=1, output_dims=2)
    self._value_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._value_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="value",
        **common_kwargs)
    self._global_value_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._value_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="global_value",
        **common_kwargs)
    # Builds the attention computations for multi-head dot product attention.
    # These computations could be wrapped into the keras attention layer once
    # it supports multi-head einsum computations.
    self._build_attention(output_rank)
    self._global_dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
    # self._output_dense = self._make_output_dense(
    #     free_dims, common_kwargs, "attention_output")
    self._output_dense = tf.keras.layers.Dense(
        units=self._num_heads * self._key_dim, name="dense", **common_kwargs)
  def call(self,
           hidden_states,
           attention_mask=None,
           is_index_masked=None,
           is_index_global_attn=None,
           training=None):
    """Applies dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for
    customized attention implementation.

    Args:
      hidden_states: inputs for generating query, key and value tensors.
      attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
        attention to certain positions.
      is_index_masked: boolean indicating whether the index is masked.
      is_index_global_attn: boolean indicating whether the index is global
        attention.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation.
    """
    if not self._built_from_signature:
      self._build_from_signature(
          query=hidden_states, value=hidden_states, key=hidden_states)
    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, T, N, H]
    query = self._query_dense(hidden_states)
    # `key` = [B, S, N, H]
    key = self._key_dense(hidden_states)
    # `value` = [B, S, N, H]
    value = self._value_dense(hidden_states)

    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
    batch_size, seq_len, num_heads, head_dim = get_shape_list(query)

    # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
    attn_scores = self._sliding_chunks_query_key_matmul(
        query, key, self._one_sided_attn_window_size)

    # diagonal mask with zeros everywhere and -inf in place of padding
    diagonal_mask = self._sliding_chunks_query_key_matmul(
        tf.ones(get_shape_list(attention_mask)),
        attention_mask,
        self._one_sided_attn_window_size,
    )

    # pad local attention probs
    attn_scores += diagonal_mask

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(attn_scores),
          [
              batch_size, seq_len, self._num_heads,
              self._one_sided_attn_window_size * 2 + 1
          ],
          message=f"attn_probs should be of size "
          f"({batch_size}, {seq_len}, {num_heads}, "
          f"{self._one_sided_attn_window_size * 2 + 1}),"
          f" but is of size {get_shape_list(attn_scores)}",
      )

    # compute global attn indices required throughout the forward fn
    (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn,
                                      self.global_attention_size)
    # this function is only relevant for global attention
    if self.global_attention_size > 0:
      attn_scores = self._concat_with_global_key_attn_probs(
          attn_scores=attn_scores,
          query_vectors=query,
          key_vectors=key,
          max_num_global_attn_indices=max_num_global_attn_indices,
          is_index_global_attn_nonzero=is_index_global_attn_nonzero,
          is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
          is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
      )

    attn_probs = tf.nn.softmax(attn_scores, axis=-1)

    # softmax sometimes inserts NaN if all positions are masked,
    # replace them with 0
    # Make sure to create a mask with the proper shape:
    # if is_global_attn==True => [batch_size, seq_len, self.num_heads,
    # self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
    # if is_global_attn==False => [batch_size, seq_len, self.num_heads,
    # self.one_sided_attn_window_size * 2 + 1]
    if self.global_attention_size > 0:
      masked_index = tf.tile(
          is_index_masked[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
           max_num_global_attn_indices + 1),
      )
    else:
      masked_index = tf.tile(
          is_index_masked[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
      )
    attn_probs = tf.where(
        masked_index,
        tf.zeros(get_shape_list(masked_index), dtype=attn_probs.dtype),
        attn_probs,
    )

    layer_head_mask = None
    if layer_head_mask is not None:
      if tf.executing_eagerly():
        tf.debugging.assert_equal(
            get_shape_list(layer_head_mask), [self._num_heads],
            message=f"Head mask for a single layer should be of size "
            f"{(self._num_heads)}, but is "
            f"{get_shape_list(layer_head_mask)}",
        )
      attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

    # apply dropout
    attn_probs = self._dropout_layer(attn_probs, training=training)
    value_vectors = tf.reshape(
        value, (batch_size, seq_len, self._num_heads, self._key_dim))

    # if global attention, compute sum of global and local attn
    if self.global_attention_size > 0:
      attn_output = self._compute_attn_output_with_global_indices(
          value_vectors=value_vectors,
          attn_probs=attn_probs,
          max_num_global_attn_indices=max_num_global_attn_indices,
          is_index_global_attn_nonzero=is_index_global_attn_nonzero,
          is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
      )
    else:
      attn_output = self._sliding_chunks_matmul_attn_probs_value(
          attn_probs, value_vectors, self._one_sided_attn_window_size)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(attn_output),
          [batch_size, seq_len, self._num_heads, head_dim],
          message="Unexpected size",
      )
    attn_output = tf.reshape(
        attn_output, (batch_size, seq_len, self._num_heads * self._key_dim))

    # FIXME
    # compute value for global attention and overwrite to attention output
    # TODO(crickwu): remove the redundant computation
    if self.global_attention_size > 0:
      attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(  # pylint: disable=unused-variable
          attn_output=attn_output,
          hidden_states=hidden_states,
          max_num_global_attn_indices=max_num_global_attn_indices,
          layer_head_mask=layer_head_mask,
          is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
          is_index_global_attn_nonzero=is_index_global_attn_nonzero,
          is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
          is_index_masked=is_index_masked,
          training=training,
      )
    else:
      global_attn_probs = tf.zeros(
          (batch_size, self._num_heads, max_num_global_attn_indices, seq_len))

    # make sure that local attention probabilities are set to 0 for indices
    # of global attn
    if self.global_attention_size > 0:
      masked_global_attn_index = tf.tile(
          is_index_global_attn[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
           max_num_global_attn_indices + 1),
      )
    else:
      masked_global_attn_index = tf.tile(
          is_index_global_attn[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
      )
    attn_probs = tf.where(
        masked_global_attn_index,
        tf.zeros(
            get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
        attn_probs,
    )

    # we can return extra information here
    # (attn_output, attn_probs, global_attn_probs)
    return attn_output
  def get_config(self):
    config = {
        "layer_id": self._layer_id,
        # Serialize the full (two-sided) window and the global attention size
        # so that deserialization round-trips: `__init__` expects an even
        # `attention_window` and a `global_attention_size` argument.
        "attention_window": self._attention_window,
        "global_attention_size": self.global_attention_size,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
  def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
    """Matrix multiplication of query and key tensors.

    This multiplication uses a sliding window attention pattern.
    This implementation splits the input into overlapping chunks of size
    2w (e.g. 512 for pretrained Longformer) with an overlap of size
    window_overlap.

    Args:
      query: query tensor.
      key: key tensor.
      window_overlap: int.

    Returns:
      diagonal_attention_scores: tensor.
    """
    batch_size, seq_len, num_heads, head_dim = get_shape_list(query)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          seq_len % (window_overlap * 2),
          0,
          message=f"Sequence length should be multiple of {window_overlap * 2}. "
          f"Given {seq_len}",
      )
      tf.debugging.assert_equal(
          get_shape_list(query),
          get_shape_list(key),
          message=f"Shape of query and key should be equal, but got query: "
          f"{get_shape_list(query)} and key: {get_shape_list(key)}",
      )

    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one,
    # then chunk seq_len into chunks of size window_overlap * 2
    query = tf.reshape(
        tf.transpose(query, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )
    key = tf.reshape(
        tf.transpose(key, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim))
    chunked_query = self._chunk(query, window_overlap)
    chunked_key = self._chunk(key, window_overlap)

    # matrix multiplication
    # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
    chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
    chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query,
                                         chunked_key)  # multiply

    # convert diagonals into columns
    paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
    diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
        chunked_attention_scores, paddings)

    # allocate space for the overall attention matrix where the chunks are
    # combined. The last dimension has (window_overlap * 2 + 1) columns.
    # The first (window_overlap) columns are the window_overlap lower
    # triangles (attention from a word to window_overlap previous words).
    # The following column is the attention score from each word to itself,
    # then followed by window_overlap columns for the upper triangle.

    # copy parts from diagonal_chunked_attention_scores into the combined
    # matrix of attentions - copying the main diagonal and the upper triangle
    # TODO(crickwu): This code is most likely not very efficient and should
    # be improved.
    diagonal_attn_scores_up_triang = tf.concat(
        [
            diagonal_chunked_attention_scores[:, :, :window_overlap,
                                              :window_overlap + 1],
            diagonal_chunked_attention_scores[:, -1:, window_overlap:,
                                              :window_overlap + 1],
        ],
        axis=1,
    )

    # - copying the lower triangle
    diagonal_attn_scores_low_triang = tf.concat(
        [
            tf.zeros(
                (batch_size * num_heads, 1, window_overlap, window_overlap),
                dtype=diagonal_chunked_attention_scores.dtype,
            ),
            diagonal_chunked_attention_scores[:, :, -(window_overlap + 1):-1,
                                              window_overlap + 1:],
        ],
        axis=1,
    )
    diagonal_attn_scores_first_chunk = tf.concat(
        [
            tf.roll(
                diagonal_chunked_attention_scores,
                shift=[1, window_overlap],
                axis=[2, 3],
            )[:, :, :window_overlap, :window_overlap],
            tf.zeros(
                (batch_size * num_heads, 1, window_overlap, window_overlap),
                dtype=diagonal_chunked_attention_scores.dtype,
            ),
        ],
        axis=1,
    )
    first_chunk_mask = (
        tf.tile(
            tf.range(chunks_count + 1)[None, :, None, None],
            (batch_size * num_heads, 1, window_overlap, window_overlap),
        ) < 1)
    diagonal_attn_scores_low_triang = tf.where(
        first_chunk_mask,
        diagonal_attn_scores_first_chunk,
        diagonal_attn_scores_low_triang,
    )

    # merging upper and lower triangle
    diagonal_attention_scores = tf.concat(
        [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang],
        axis=-1)

    # separate batch_size and num_heads dimensions again
    diagonal_attention_scores = tf.transpose(
        tf.reshape(
            diagonal_attention_scores,
            (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
        ),
        (0, 2, 1, 3),
    )

    diagonal_attention_scores = self._mask_invalid_locations(
        diagonal_attention_scores, window_overlap)
    return diagonal_attention_scores
  @staticmethod
  def _mask_invalid_locations(input_tensor, window_overlap):
    # create correct upper triangle bool mask
    mask_2d_upper = tf.reverse(
        tf.linalg.band_part(
            tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
        axis=[0],
    )

    # pad to full matrix
    padding = tf.convert_to_tensor(
        [[0, get_shape_list(input_tensor)[1] - window_overlap],
         [0, get_shape_list(input_tensor)[3] - window_overlap - 1]])

    # create lower mask
    mask_2d = tf.pad(mask_2d_upper, padding)

    # combine with upper mask
    mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

    # broadcast to full matrix
    mask_4d = tf.tile(mask_2d[None, :, None, :],
                      (get_shape_list(input_tensor)[0], 1, 1, 1))

    # inf tensor used for masking
    inf_tensor = -float("inf") * tf.ones_like(input_tensor)

    # mask
    input_tensor = tf.where(
        tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

    return input_tensor
  def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value,
                                              window_overlap):
    """Same as _sliding_chunks_query_key_matmul but for attn_probs and value."""
    batch_size, seq_len, num_heads, head_dim = get_shape_list(value)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          seq_len % (window_overlap * 2),
          0,
          message="Seq_len has to be multiple of 2 * window_overlap",
      )
      tf.debugging.assert_equal(
          get_shape_list(attn_probs)[:3],
          get_shape_list(value)[:3],
          message="value and attn_probs must have same dims (except head_dim)",
      )
      tf.debugging.assert_equal(
          get_shape_list(attn_probs)[3],
          2 * window_overlap + 1,
          message="attn_probs last dim has to be 2 * window_overlap + 1",
      )

    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one, then chunk seq_len
    # into chunks of size 2 window overlap
    chunked_attn_probs = tf.reshape(
        tf.transpose(attn_probs, (0, 2, 1, 3)),
        (
            batch_size * num_heads,
            seq_len // window_overlap,
            window_overlap,
            2 * window_overlap + 1,
        ),
    )

    # group batch_size and num_heads dimensions into one
    value = tf.reshape(
        tf.transpose(value, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )

    # pad seq_len with w at the beginning of the sequence and another window
    # overlap at the end
    paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap],
                                     [0, 0]])
    padded_value = tf.pad(value, paddings, constant_values=-1)

    # chunk padded_value into chunks of size 3 window overlap and an overlap
    # of size window overlap
    frame_size = 3 * window_overlap * head_dim
    frame_hop_size = (
        get_shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
    chunked_value = tf.signal.frame(
        tf.reshape(padded_value, (batch_size * num_heads, -1)),
        frame_size,
        frame_hop_size,
    )
    chunked_value = tf.reshape(
        chunked_value,
        (batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
         head_dim),
    )

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(chunked_value),
          [
              batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
              head_dim
          ],
          message="Chunked value has the wrong shape",
      )

    chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
    context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
    context = tf.transpose(
        tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
        (0, 2, 1, 3),
    )
    return context
  @staticmethod
  def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
    """Pads rows and then flips rows and columns."""
    # padding value is not important because it will be overwritten
    hidden_states_padded = tf.pad(hidden_states_padded, paddings)
    batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
        hidden_states_padded)
    hidden_states_padded = tf.reshape(
        hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
    return hidden_states_padded
  @staticmethod
  def _pad_and_diagonalize(chunked_hidden_states):
    """Shifts every row 1 step right, converting columns into diagonals.

    Example:
      chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
                               -1.8348, 0.7672, 0.2986, 0.0285,
                               -0.7584, 0.4206, -0.0405, 0.1599,
                               2.0514, -1.1600, 0.5372, 0.2629 ]
      window_overlap = num_rows = 4
      (pad & diagonalize) =>
      [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
        0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
        0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
        0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]

    Args:
      chunked_hidden_states: tensor.

    Returns:
      padded_hidden_states: tensor.
    """
    total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(
        chunked_hidden_states)
    paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0],
                                     [0, window_overlap + 1]])
    chunked_hidden_states = tf.pad(chunked_hidden_states, paddings)
    chunked_hidden_states = tf.reshape(chunked_hidden_states,
                                       (total_num_heads, num_chunks, -1))
    chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap]
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (total_num_heads, num_chunks, window_overlap,
         window_overlap + hidden_dim),
    )
    chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
    return chunked_hidden_states
  @staticmethod
  def _chunk(hidden_states, window_overlap):
    """Converts into overlapping chunks. Chunk size = 2w, overlap size = w."""
    batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
    num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

    # define frame size and frame stride (similar to convolution)
    frame_hop_size = window_overlap * hidden_dim
    frame_size = 2 * frame_hop_size
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length * hidden_dim))

    # chunk with overlap
    chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
                                            frame_hop_size)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(chunked_hidden_states),
          [batch_size, num_output_chunks, frame_size],
          message=f"Make sure chunking is correctly applied. Chunked hidden "
          f"states should have output dimension "
          f"{[batch_size, frame_size, num_output_chunks]}, but got "
          f"{get_shape_list(chunked_hidden_states)}.",
      )

    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
    )
    return chunked_hidden_states
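To make the framing concrete, a small hand-checked sketch: with seq_length=8, hidden_dim=4 and window_overlap=2, the flattened sequence of 32 values is framed with frame_size=16 and hop=8, so the 8 positions become 3 overlapping chunks of 4 positions each (positions 0-3, 2-5 and 4-7):

x = tf.reshape(tf.range(32, dtype=tf.float32), (1, 8, 4))
chunks = LongformerAttention._chunk(x, window_overlap=2)
# get_shape_list(chunks) == [1, 3, 4, 4]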
  @staticmethod
  def _get_global_attn_indices(is_index_global_attn, global_attention_size):
    """Computes global attn indices required throughout the forward pass."""
    # All global attention sizes are fixed through global_attention_size.
    batch_size, _ = get_shape_list(is_index_global_attn)
    max_num_global_attn_indices = global_attention_size
    row_indices = tf.range(batch_size)
    row_indices = tf.repeat(
        tf.expand_dims(row_indices, axis=0),
        repeats=[global_attention_size],
        axis=0)
    row_indices = tf.reshape(row_indices,
                             (batch_size * global_attention_size, 1))
    col_indices = tf.range(global_attention_size)
    col_indices = tf.repeat(
        tf.expand_dims(col_indices, axis=1), repeats=[batch_size], axis=0)
    is_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
                                             axis=1)
    # this is actually the same as `is_index_global_attn_nonzero`,
    # since we assume all global attention are the same size
    is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
                                                   axis=1)
    # empty tensor
    is_local_index_no_global_attn_nonzero = tf.reshape(
        tf.expand_dims(tf.range(0), axis=1), (0, 2))
    return (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    )
  def _concat_with_global_key_attn_probs(
      self,
      attn_scores,
      key_vectors,
      query_vectors,
      max_num_global_attn_indices,
      is_index_global_attn_nonzero,
      is_local_index_global_attn_nonzero,
      is_local_index_no_global_attn_nonzero,
  ):
    batch_size = get_shape_list(key_vectors)[0]

    # select global key vectors
    global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
    # create only global key vectors
    key_vectors_only_global = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_key_vectors,
        shape=(
            batch_size,
            max_num_global_attn_indices,
            self._num_heads,
            self._key_dim,
        ),
    )
    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
                                           key_vectors_only_global)
    # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
    attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key,
                                                    (0, 3, 1, 2))
    mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
                 ) + tuple(
                     get_shape_list(attn_probs_from_global_key_trans)[-2:])
    mask = tf.ones(mask_shape) * -10000.0
    mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

    # scatter mask
    attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
        attn_probs_from_global_key_trans,
        is_local_index_no_global_attn_nonzero,
        mask,
    )
    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans,
                                              (0, 2, 3, 1))

    # concat to attn_probs
    # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
    attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)
    return attn_scores
  def _compute_attn_output_with_global_indices(
      self,
      value_vectors,
      attn_probs,
      max_num_global_attn_indices,
      is_index_global_attn_nonzero,
      is_local_index_global_attn_nonzero,
  ):
    batch_size = get_shape_list(attn_probs)[0]

    # cut local attn probs to global only
    attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
    # select global value vectors
    global_value_vectors = tf.gather_nd(value_vectors,
                                        is_index_global_attn_nonzero)
    # create only global value vectors
    value_vectors_only_global = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_value_vectors,
        shape=(
            batch_size,
            max_num_global_attn_indices,
            self._num_heads,
            self._key_dim,
        ),
    )

    # compute attn output only global
    attn_output_only_global = tf.einsum("blhs,bshd->blhd",
                                        attn_probs_only_global,
                                        value_vectors_only_global)
    # reshape attn probs
    attn_probs_without_global = attn_probs[:, :, :,
                                           max_num_global_attn_indices:]

    # compute attn output with global
    attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
        attn_probs_without_global, value_vectors,
        self._one_sided_attn_window_size)
    return attn_output_only_global + attn_output_without_global
  def _compute_global_attn_output_from_hidden(
      self,
      attn_output,
      hidden_states,
      max_num_global_attn_indices,
      layer_head_mask,
      is_local_index_global_attn_nonzero,
      is_index_global_attn_nonzero,
      is_local_index_no_global_attn_nonzero,
      is_index_masked,
      training,
  ):
    batch_size, seq_len = get_shape_list(hidden_states)[:2]

    # prepare global hidden states
    global_attn_hidden_states = tf.gather_nd(hidden_states,
                                             is_index_global_attn_nonzero)
    global_attn_hidden_states = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_attn_hidden_states,
        shape=(batch_size, max_num_global_attn_indices,
               self._num_heads * self._key_dim),
    )

    # global key, query, value
    global_query_vectors_only_global = self._global_query_dense(
        global_attn_hidden_states)
    global_key_vectors = self._global_key_dense(hidden_states)
    global_value_vectors = self._global_value_dense(hidden_states)

    # normalize
    global_query_vectors_only_global /= tf.math.sqrt(
        tf.cast(self._key_dim, dtype=global_query_vectors_only_global.dtype))
    global_query_vectors_only_global = self.reshape_and_transpose(
        global_query_vectors_only_global, batch_size)
    global_key_vectors = self.reshape_and_transpose(global_key_vectors,
                                                    batch_size)
    global_value_vectors = self.reshape_and_transpose(global_value_vectors,
                                                      batch_size)

    # compute attn scores
    global_attn_scores = tf.matmul(
        global_query_vectors_only_global, global_key_vectors, transpose_b=True)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(global_attn_scores),
          [batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
          message=f"global_attn_scores have the wrong size. Size should be "
          f"{(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, "
          f"but is {get_shape_list(global_attn_scores)}.",
      )

    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
    )
    global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
    mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
                 ) + tuple(get_shape_list(global_attn_scores_trans)[-2:])
    global_attn_mask = tf.ones(mask_shape) * -10000.0
    global_attn_mask = tf.cast(
        global_attn_mask, dtype=global_attn_scores_trans.dtype)

    # scatter mask
    global_attn_scores_trans = tf.tensor_scatter_nd_update(
        global_attn_scores_trans,
        is_local_index_no_global_attn_nonzero,
        global_attn_mask,
    )
    global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))

    # mask global attn scores
    attn_mask = tf.tile(is_index_masked[:, None, None, :],
                        (1, get_shape_list(global_attn_scores)[1], 1, 1))
    global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size * self._num_heads, max_num_global_attn_indices, seq_len),
    )

    # compute global attn probs
    global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)

    # apply layer head masking
    if layer_head_mask is not None:
      if tf.executing_eagerly():
        tf.debugging.assert_equal(
            get_shape_list(layer_head_mask), [self._num_heads],
            message=f"Head mask for a single layer should be of size "
            f"{(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
        )
      global_attn_probs_float = tf.reshape(
          layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
              global_attn_probs_float,
              (batch_size, self._num_heads, max_num_global_attn_indices,
               seq_len))
      global_attn_probs_float = tf.reshape(
          global_attn_probs_float,
          (batch_size * self._num_heads, max_num_global_attn_indices, seq_len))

    # dropout
    global_attn_probs = self._global_dropout_layer(
        global_attn_probs_float, training=training)

    # global attn output
    global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(global_attn_output),
          [
              batch_size * self._num_heads, max_num_global_attn_indices,
              self._key_dim
          ],
          message=f"global_attn_output tensor has the wrong size. Size should be "
          f"{(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, "
          f"but is {get_shape_list(global_attn_output)}.",
      )

    global_attn_output = tf.reshape(
        global_attn_output,
        (batch_size, self._num_heads, max_num_global_attn_indices,
         self._key_dim),
    )

    # get only non zero global attn output
    nonzero_global_attn_output = tf.gather_nd(
        tf.transpose(global_attn_output, (0, 2, 1, 3)),
        is_local_index_global_attn_nonzero,
    )
    nonzero_global_attn_output = tf.reshape(
        nonzero_global_attn_output,
        (get_shape_list(is_local_index_global_attn_nonzero)[0], -1),
    )

    # overwrite values with global attention
    attn_output = tf.tensor_scatter_nd_update(attn_output,
                                              is_index_global_attn_nonzero,
                                              nonzero_global_attn_output)

    global_attn_probs = tf.reshape(
        global_attn_probs,
        (batch_size, self._num_heads, max_num_global_attn_indices, seq_len))

    attn_output = self._output_dense(attn_output)
    return attn_output, global_attn_probs
  def reshape_and_transpose(self, vector, batch_size):
    return tf.reshape(
        tf.transpose(
            tf.reshape(vector,
                       (batch_size, -1, self._num_heads, self._key_dim)),
            (0, 2, 1, 3),
        ),
        (batch_size * self._num_heads, -1, self._key_dim),
    )
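A minimal eager sketch of driving the layer stand-alone, mirroring the tests below (the sizes and the all-zeros additive mask are illustrative; seq_len must be a multiple of attention_window):

layer = LongformerAttention(
    num_heads=2, key_dim=4, value_dim=4,
    layer_id=0, attention_window=4, global_attention_size=0)
hidden = tf.random.normal((1, 8, 8))  # (batch, seq_len, hidden)
attention_mask = tf.zeros((1, 8, 1, 1))  # additive mask; 0 means "attend"
out = layer(
    hidden_states=hidden,
    attention_mask=attention_mask,
    is_index_masked=tf.zeros((1, 8), tf.bool),
    is_index_global_attn=tf.zeros((1, 8), tf.bool))
# out has shape (1, 8, num_heads * key_dim) == (1, 8, 8)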
official/projects/longformer/longformer_attention_test.py
0 → 100644
View file @
8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""
import numpy as np
import tensorflow as tf

from official.modeling.tf_utils import get_shape_list
from official.projects.longformer import longformer_attention


def _create_mock_attention_data(num_heads,
                                key_dim,
                                value_dim,
                                q_seq_length,
                                kv_seq_length,
                                batch_size,
                                include_mask=False):
  """Creates mock testing data.

  Args:
    num_heads: `int`, Number of attention heads.
    key_dim: `int`, Size of query head.
    value_dim: `int`, Size of key, value dim.
    q_seq_length: `int`, query sequence length of the input.
    kv_seq_length: `int`, key, value sequence length of the input.
    batch_size: `int`, the batch size.
    include_mask: optional `bool`, whether or not to include mask data.

  Returns:
    A dictionary with `str` as keys and `Tensor` as values.
  """
  query_shape = (batch_size, q_seq_length, key_dim)
  value_shape = (batch_size, kv_seq_length, value_dim)

  data = dict(
      query=tf.random.normal(shape=query_shape),
      value=tf.random.normal(shape=value_shape),
      key=tf.random.normal(shape=value_shape))

  total_seq_length = kv_seq_length

  if include_mask:
    mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
    mask_data = np.random.randint(2, size=mask_shape).astype('float32')
    mask_data = dict(attention_mask=mask_data)
    data.update(mask_data)

  return data
class LongformerAttentionTest(tf.test.TestCase):

  def setUp(self):
    super(LongformerAttentionTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  def _get_hidden_states(self):
    return tf.convert_to_tensor(
        [[
            [
                4.98332758e-01, 2.69175139e00, -7.08081422e-03,
                1.04915401e00, -1.83476661e00, 7.67220476e-01,
                2.98580543e-01, 2.84803992e-02,
            ],
            [
                -7.58357372e-01, 4.20635998e-01, -4.04739919e-02,
                1.59924145e-01, 2.05135748e00, -1.15997978e00,
                5.37166397e-01, 2.62873606e-01,
            ],
            [
                -1.69438001e00, 4.17574660e-01, -1.49196962e00,
                -1.76483717e00, -1.94566312e-01, -1.71183858e00,
                7.72903565e-01, -1.11557056e00,
            ],
            [
                5.44028163e-01, 2.05466114e-01, -3.63045868e-01,
                2.41865062e-01, 3.20348382e-01, -9.05611176e-01,
                -1.92690727e-01, -1.19917547e00,
            ],
        ]],
        dtype=tf.float32,
    )
  def test_diagonalize(self):
    hidden_states = self._get_hidden_states()
    hidden_states = tf.reshape(hidden_states,
                               (1, 8, 4))  # set seq length = 8, hidden dim = 4
    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    window_overlap_size = get_shape_list(chunked_hidden_states)[2]
    self.assertEqual(window_overlap_size, 4)

    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
        chunked_hidden_states)
    self.assertEqual(
        get_shape_list(padded_hidden_states)[-1],
        get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1)

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, 0, :4],
        chunked_hidden_states[0, 0, 0],
        rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, 0, 4:],
        tf.zeros((3,), dtype=tf.dtypes.float32),
        rtol=1e-3)

    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, 3:],
        chunked_hidden_states[0, 0, -1],
        rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, :3],
        tf.zeros((3,), dtype=tf.dtypes.float32),
        rtol=1e-3)
  def test_pad_and_transpose_last_two_dims(self):
    hidden_states = self._get_hidden_states()
    self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])

    # pad along seq length dim
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
                           dtype=tf.dtypes.int32)
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
        hidden_states, paddings)
    self.assertEqual(get_shape_list(padded_hidden_states), [1, 1, 8, 5])

    expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
    tf.debugging.assert_near(
        expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
    tf.debugging.assert_near(
        hidden_states[0, 0, -1, :],
        tf.reshape(padded_hidden_states, (1, -1))[0, 24:32],
        rtol=1e-6)
  def test_mask_invalid_locations(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 1)
    hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 2)
    hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, :, :3], 2)
    hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, 2:, :], 2)

    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)), 8)
    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)), 24)
    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)), 24)
    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)), 12)
  def test_chunk(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))

    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    # expected slices across chunk and seq length dim
    expected_slice_along_seq_length = tf.convert_to_tensor(
        [0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
    expected_slice_along_chunk = tf.convert_to_tensor(
        [0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)

    self.assertEqual(get_shape_list(chunked_hidden_states), [1, 3, 4, 4])
    tf.debugging.assert_near(
        chunked_hidden_states[0, :, 0, 0],
        expected_slice_along_seq_length,
        rtol=1e-3)
    tf.debugging.assert_near(
        chunked_hidden_states[0, 0, :, 0],
        expected_slice_along_chunk,
        rtol=1e-3)
  def test_layer_local_attn(self):
    hidden_states = self._get_hidden_states()
    batch_size, seq_length, _ = hidden_states.shape
    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,
        value_dim=4,
        layer_id=0,
        attention_window=4,
        global_attention_size=0,
    )
    attention_mask = tf.zeros((batch_size, seq_length),
                              dtype=tf.dtypes.float32)
    is_index_global_attn = tf.math.greater(attention_mask, 1)

    attention_mask = tf.where(
        tf.range(4)[None, :, None, None] > 1, -10000.0,
        attention_mask[:, :, None, None])
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertTrue(output_hidden_states.shape, (1, 4, 8))
  def test_layer_global_attn(self):
    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,
        value_dim=4,
        layer_id=0,
        attention_window=4,
        global_attention_size=1,
    )
    hidden_states = self._get_hidden_states()
    hidden_states = tf.concat(
        [self._get_hidden_states(),
         self._get_hidden_states() - 0.5], axis=0)
    _, seq_length, _ = hidden_states.shape

    # create attn mask
    attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
    attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

    attention_mask_1 = tf.where(
        tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_1)
    attention_mask_1 = tf.where(
        tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
    attention_mask_2 = tf.where(
        tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_2)

    attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
    is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states,
        attention_mask=-tf.math.abs(attention_mask),
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )[0]

    self.assertTrue(output_hidden_states.shape, (2, 4, 8))


if __name__ == '__main__':
  tf.test.main()
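Assuming the standard model-garden test setup, these cases should run through the usual tf.test entry point, e.g. (module path taken from the file listing above):

python3 -m official.projects.longformer.longformer_attention_test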
official/projects/longformer/longformer_encoder.py
0 → 100644
View file @
8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer encoder. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, List, Optional, Union

from absl import logging
import tensorflow as tf

from official.modeling.tf_utils import get_shape_list
from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock

_Initializer = Union[str, tf.keras.initializers.Initializer]

_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)


class LongformerEncoder(tf.keras.layers.Layer):
  """LongformerEncoder.

  Args:
    vocab_size: The size of the token vocabulary.
    attention_window: list of ints representing the window size for each
      layer.
    global_attention_size: the size of global attention used for each token.
    pad_token_id: the token id for the pad token.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers.
    num_attention_heads: The number of attention heads for each transformer.
      The hidden size must be divisible by the number of attention heads.
    max_sequence_length: The maximum sequence length that this encoder can
      consume. If None, max_sequence_length uses the value from sequence
      length. This determines the variable shape for positional embeddings.
    type_vocab_size: The number of types that the 'type_ids' input can take.
    inner_dim: The output dimension of the first Dense layer in a two-layer
      feedforward network for each transformer.
    inner_activation: The activation for the first Dense layer in a two-layer
      feedforward network for each transformer.
    output_dropout: Dropout probability for the post-attention and output
      dropout.
    attention_dropout: The dropout rate to use for the attention layers
      within the transformer layers.
    initializer: The initializer to use for all weights in this encoder.
    output_range: The sequence output range, [0, output_range), by slicing
      the target sequence of the last transformer layer. `None` means the
      entire target sequence will attend to the source sequence, which
      yields the full output.
    embedding_width: The width of the word embeddings. If the embedding
      width is not equal to hidden size, embedding parameters will be
      factorized into two matrices in the shape of ['vocab_size',
      'embedding_width'] and ['embedding_width', 'hidden_size']
      ('embedding_width' is usually much smaller than 'hidden_size').
    embedding_layer: An optional Layer instance which will be called to
      generate embeddings for the input word IDs.
    norm_first: Whether to normalize inputs to attention and intermediate
      dense layers. If set False, output of attention and intermediate dense
      layers is normalized.
  """
  def __init__(
      self,
      vocab_size: int,
      attention_window: Union[List[int], int] = 512,
      global_attention_size: int = 0,
      pad_token_id: int = 1,
      hidden_size: int = 768,
      num_layers: int = 12,
      num_attention_heads: int = 12,
      max_sequence_length: int = 512,
      type_vocab_size: int = 16,
      inner_dim: int = 3072,
      inner_activation: Callable[..., Any] = _approx_gelu,
      output_dropout: float = 0.1,
      attention_dropout: float = 0.1,
      initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      output_range: Optional[int] = None,
      embedding_width: Optional[int] = None,
      embedding_layer: Optional[tf.keras.layers.Layer] = None,
      norm_first: bool = False,
      **kwargs):
    super().__init__(**kwargs)
    # Longformer args
    self._attention_window = attention_window
    self._global_attention_size = global_attention_size
    self._pad_token_id = pad_token_id

    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    if embedding_width is None:
      embedding_width = hidden_size

    if embedding_layer is None:
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
          initializer=initializer,
          name='word_embeddings')
    else:
      self._embedding_layer = embedding_layer

    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')

    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)

    self._embedding_dropout = tf.keras.layers.Dropout(
        rate=output_dropout, name='embedding_dropout')

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    self._embedding_projection = None
    if embedding_width != hidden_size:
      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')

    self._transformer_layers = []
    self._attention_mask_layer = layers.SelfAttentionMask(
        name='self_attention_mask')
    for i in range(num_layers):
      layer = LongformerEncoderBlock(
          global_attention_size=global_attention_size,
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          attention_window=attention_window[i],
          layer_id=i,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=output_range if i == num_layers - 1 else None,
          kernel_initializer=initializer,
          name=f'transformer/layer_{i}')
      self._transformer_layers.append(layer)

    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')

    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'attention_window': attention_window,
        'global_attention_size': global_attention_size,
        'pad_token_id': pad_token_id,
    }
    self.inputs = dict(
        input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
        input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
        input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
  def call(self, inputs):
    word_embeddings = None
    if isinstance(inputs, dict):
      word_ids = inputs.get('input_word_ids')  # input_ids
      mask = inputs.get('input_mask')  # attention_mask
      type_ids = inputs.get('input_type_ids')  # token_type_ids
      word_embeddings = inputs.get('input_word_embeddings',
                                   None)  # inputs_embeds
    else:
      raise ValueError(f'Unexpected inputs type to {self.__class__}.')

    (
        padding_len,
        word_ids,
        mask,
        type_ids,
        word_embeddings,
    ) = self._pad_to_window_size(
        word_ids=word_ids,
        mask=mask,
        type_ids=type_ids,
        word_embeddings=word_embeddings,
        pad_token_id=self._pad_token_id)

    if word_embeddings is None:
      word_embeddings = self._embedding_layer(word_ids)

    # absolute position embeddings.
    position_embeddings = self._position_embedding_layer(word_embeddings)
    type_embeddings = self._type_embedding_layer(type_ids)

    embeddings = word_embeddings + position_embeddings + type_embeddings
    embeddings = self._embedding_norm_layer(embeddings)
    embeddings = self._embedding_dropout(embeddings)

    if self._embedding_projection is not None:
      embeddings = self._embedding_projection(embeddings)

    batch_size, seq_len = get_shape_list(mask)
    # create masks with fixed len global_attention_size
    mask = tf.transpose(
        tf.concat(
            values=[
                tf.ones(
                    (self._global_attention_size, batch_size), tf.int32) * 2,
                tf.transpose(mask)[self._global_attention_size:]
            ],
            axis=0))
    is_index_masked = tf.math.less(mask, 1)
    is_index_global_attn = tf.transpose(
        tf.concat(
            values=[
                tf.ones((self._global_attention_size, batch_size), tf.bool),
                tf.zeros((seq_len - self._global_attention_size, batch_size),
                         tf.bool)
            ],
            axis=0))

    # Longformer
    attention_mask = mask
    extended_attention_mask = tf.reshape(
        attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1))
    attention_mask = tf.cast(
        tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0

    encoder_outputs = []
    x = embeddings
    # TFLongformerEncoder
    for layer in self._transformer_layers:
      x = layer([x, attention_mask, is_index_masked, is_index_global_attn])
      encoder_outputs.append(x)

    last_encoder_output = encoder_outputs[-1]
    if padding_len > 0:
      last_encoder_output = last_encoder_output[:, :-padding_len]
    first_token_tensor = last_encoder_output[:, 0, :]
    pooled_output = self._pooler_layer(first_token_tensor)

    return dict(
        sequence_output=last_encoder_output,
        pooled_output=pooled_output,
        encoder_outputs=encoder_outputs)
  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    return dict(self._config)

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if 'embedding_layer' in config and config['embedding_layer'] is not None:
      warn_string = (
          'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you continue to '
          'train this model, the embedding layer will no longer be shared. '
          'To work around this, load the model outside of the Keras API.')
      print('WARNING: ' + warn_string)
      logging.warn(warn_string)
    return cls(**config)
  def _pad_to_window_size(
      self,
      word_ids,
      mask,
      type_ids,
      word_embeddings,
      pad_token_id,
  ):
    # Pad input sequences up to a multiple of the largest attention window.
    attention_window = max(self._attention_window)

    assert (attention_window % 2 == 0), (
        '`attention_window` should be an even value. '
        f'Given {attention_window}')

    input_shape = get_shape_list(
        word_ids) if word_ids is not None else get_shape_list(word_embeddings)
    batch_size, seq_len = input_shape[:2]
    if seq_len is not None:
      padding_len = (attention_window -
                     seq_len % attention_window) % attention_window
    else:
      padding_len = 0

    paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

    if word_ids is not None:
      word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)

    if word_embeddings is not None:

      def pad_embeddings():
        # Use the `pad_token_id` argument (matching `self._pad_token_id`
        # passed by the caller) to embed the padding positions.
        word_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
        word_embeddings_padding = self._embedding_layer(word_ids_padding)
        return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)

      word_embeddings = tf.cond(
          tf.math.greater(padding_len, 0), pad_embeddings,
          lambda: word_embeddings)

    # No attention on the padding tokens.
    mask = tf.pad(mask, paddings, constant_values=False)
    # Pad with token_type_id = 0.
    token_type_ids = tf.pad(type_ids, paddings, constant_values=0)

    return (
        padding_len,
        word_ids,
        mask,
        token_type_ids,
        word_embeddings,
    )
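To make the mask bookkeeping in `call` above concrete, here is a minimal standalone sketch (not part of the checked-in file) that reproduces how the first `global_attention_size` positions are promoted to global attention and how the boolean index masks and the additive mask are derived. The batch size, sequence length, and `global_attention_size` are illustrative values.

# Standalone sketch of the Longformer mask construction above.
import tensorflow as tf

batch_size, seq_len, global_attention_size = 2, 128, 1
mask = tf.ones((batch_size, seq_len), tf.int32)  # e.g. `input_mask`

# First `global_attention_size` positions get mask value 2 (global attention).
mask = tf.transpose(
    tf.concat(
        values=[
            tf.ones((global_attention_size, batch_size), tf.int32) * 2,
            tf.transpose(mask)[global_attention_size:]
        ],
        axis=0))
is_index_masked = tf.math.less(mask, 1)  # True where padding
is_index_global_attn = tf.transpose(
    tf.concat(
        values=[
            tf.ones((global_attention_size, batch_size), tf.bool),
            tf.zeros((seq_len - global_attention_size, batch_size), tf.bool)
        ],
        axis=0))

# Additive mask: 0.0 only where mask == 1; padding (0) and global tokens (2)
# both map to -10000.0, and the attention layer is expected to handle global
# positions separately through `is_index_global_attn`.
extended = tf.reshape(mask, (batch_size, seq_len, 1, 1))
additive_mask = tf.cast(tf.math.abs(1 - extended), tf.float32) * -10000.0

print(mask[0, :4].numpy())                  # [2 1 1 1]
print(is_index_global_attn[0, :4].numpy())  # [ True False False False]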
official/projects/longformer/longformer_encoder_block.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention layer. Modified From huggingface/transformers."""
import
tensorflow
as
tf
from
official.projects.longformer.longformer_attention
import
LongformerAttention
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
class
LongformerEncoderBlock
(
tf
.
keras
.
layers
.
Layer
):
"""LongformerEncoderBlock.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
"""
  def __init__(self,
               global_attention_size,
               num_attention_heads,
               inner_dim,
               inner_activation,
               # Longformer
               attention_window,
               layer_id=0,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    super().__init__(**kwargs)

    self.global_attention_size = global_attention_size
    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    # Longformer
    self._attention_window = attention_window
    self._layer_id = layer_id
    self._attention_dropout = attention_dropout
    self._attention_dropout_rate = attention_dropout
    self._output_dropout = output_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    if attention_initializer:
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = self._kernel_initializer
    self._attention_axes = attention_axes
  def build(self, input_shape):
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: "
          f"{type(input_shape)}")
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          f"The input size ({hidden_size}) is not a multiple of the number "
          f"of attention heads ({self._num_heads})")
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # TFLongformerSelfAttention + TFLongformerSelfOutput.dense
    self._attention_layer = LongformerAttention(
        # Longformer
        layer_id=self._layer_id,
        global_attention_size=self.global_attention_size,
        attention_window=self._attention_window,
        num_heads=self._num_heads,
        key_dim=self._attention_head_size,
        dropout=self._attention_dropout,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        attention_axes=self._attention_axes,
        name="self_attention",
        **common_kwargs)
    # TFLongformerSelfOutput.dropout
    self._attention_dropout = tf.keras.layers.Dropout(
        rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    # TFLongformerSelfOutput.LayerNorm
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    # TFLongformerIntermediate.dense
    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    # TFLongformerIntermediate.intermediate_act_fn
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf.keras.layers.Dropout(
        rate=self._inner_dropout)
    # TFLongformerOutput.dense
    self._output_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=self._kernel_initializer,
        **common_kwargs)
    # TFLongformerOutput.dropout
    self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    # TFLongformerOutput.LayerNorm
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)
    super().build(input_shape)
  def get_config(self):
    config = {
        # Longformer-specific arguments are included so that
        # `from_config(**config)` can reconstruct the layer.
        "global_attention_size": self.global_attention_size,
        "attention_window": self._attention_window,
        "layer_id": self._layer_id,
        "num_attention_heads": self._num_heads,
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "output_dropout": self._output_dropout_rate,
        "attention_dropout": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer": tf.keras.initializers.serialize(
            self._kernel_initializer),
        "bias_initializer": tf.keras.initializers.serialize(
            self._bias_initializer),
        "kernel_regularizer": tf.keras.regularizers.serialize(
            self._kernel_regularizer),
        "bias_regularizer": tf.keras.regularizers.serialize(
            self._bias_regularizer),
        "activity_regularizer": tf.keras.regularizers.serialize(
            self._activity_regularizer),
        "kernel_constraint": tf.keras.constraints.serialize(
            self._kernel_constraint),
        "bias_constraint": tf.keras.constraints.serialize(
            self._bias_constraint),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "inner_dropout": self._inner_dropout,
        "attention_initializer": tf.keras.initializers.serialize(
            self._attention_initializer),
        "attention_axes": self._attention_axes,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
  def call(self, inputs):
    """Transformer self-attention encoder block call.

    Args:
      inputs: a single tensor or a list of tensors. `input tensor` as the
        single sequence of embeddings, or [`input tensor`, `attention mask`,
        `is index masked`, `is index global attn`] to provide the additive
        attention mask and the Longformer boolean index masks.

    Returns:
      An output tensor with the same dimensions as input/query tensor.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 4:
        (input_tensor, attention_mask, is_index_masked,
         is_index_global_attn) = inputs
        key_value = None
      elif len(inputs) == 5:
        assert False  # No key_value
      else:
        raise ValueError(
            f"Unexpected inputs to {self.__class__} of length {len(inputs)}")
    else:
      input_tensor = inputs
      attention_mask = None
      is_index_masked = None
      is_index_global_attn = None
      key_value = None

    if self._output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
      if is_index_masked is not None:
        is_index_masked = is_index_masked[:, 0:self._output_range]
      if is_index_global_attn is not None:
        is_index_global_attn = is_index_global_attn[:, 0:self._output_range]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor

    attention_output = self._attention_layer(
        hidden_states=target_tensor,
        attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )
    # TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_tensor + attention_output
    else:
      attention_output = self._attention_layer_norm(target_tensor +
                                                    attention_output)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    # TFLongformerIntermediate
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    # TFLongformerOutput
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + layer_output

    # During mixed precision training, layer norm output is always fp32 for
    # now. Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(layer_output + attention_output)
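As a usage sketch, a single block can be exercised on its own with the four-element input list that `call` expects. This is illustrative, not part of the checked-in file: it assumes `LongformerAttention` accepts a per-layer integer window (as the constructor wiring above suggests), and it mirrors the mask construction from `LongformerEncoder.call`.

# Hypothetical standalone usage of LongformerEncoderBlock.
import tensorflow as tf
from official.projects.longformer.longformer_encoder_block import (
    LongformerEncoderBlock)

batch_size, seq_len, hidden_size, g = 1, 128, 256, 1
block = LongformerEncoderBlock(
    global_attention_size=g,
    num_attention_heads=4,
    inner_dim=512,
    inner_activation="relu",
    attention_window=32,  # per-layer window; must be even
    layer_id=0)

x = tf.random.normal((batch_size, seq_len, hidden_size))
# Mask value 2 marks the global-attention prefix, 1 marks ordinary tokens.
mask = tf.concat(
    [2 * tf.ones((batch_size, g), tf.int32),
     tf.ones((batch_size, seq_len - g), tf.int32)], axis=1)
additive_mask = tf.cast(
    tf.math.abs(1 - tf.reshape(mask, (batch_size, seq_len, 1, 1))),
    tf.float32) * -10000.0
is_index_masked = tf.math.less(mask, 1)
is_index_global_attn = tf.concat(
    [tf.ones((batch_size, g), tf.bool),
     tf.zeros((batch_size, seq_len - g), tf.bool)], axis=1)

y = block([x, additive_mask, is_index_masked, is_index_global_attn])
print(y.shape)  # (1, 128, 256)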
official/projects/longformer/longformer_encoder_test.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_encoder."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations

from official.projects.longformer.longformer_encoder import LongformerEncoder


class LongformerEncoderTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super(LongformerEncoderTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  @combinations.generate(
      combinations.combine(
          attention_window=[32, 128], global_attention_size=[0, 1, 2]))
  def test_encoder(self, attention_window, global_attention_size):
    sequence_length = 128
    batch_size = 2
    vocab_size = 1024
    hidden_size = 256
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[attention_window],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512)
    word_id_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length), dtype=np.int32)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length), dtype=np.int32)
    type_id_data = np.random.randint(
        2, size=(batch_size, sequence_length), dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))

  @combinations.generate(
      combinations.combine(
          norm_first=[True, False], global_attention_size=[0, 1, 2]))
  def test_norm_first(self, norm_first, global_attention_size):
    sequence_length = 128
    batch_size = 2
    vocab_size = 1024
    hidden_size = 256
    network = LongformerEncoder(
        global_attention_size=global_attention_size,
        vocab_size=vocab_size,
        attention_window=[32],
        hidden_size=hidden_size,
        num_layers=1,
        num_attention_heads=4,
        max_sequence_length=512,
        norm_first=norm_first)
    word_id_data = np.random.randint(
        vocab_size, size=(batch_size, sequence_length), dtype=np.int32)
    mask_data = np.random.randint(
        2, size=(batch_size, sequence_length), dtype=np.int32)
    type_id_data = np.random.randint(
        2, size=(batch_size, sequence_length), dtype=np.int32)
    inputs = {
        'input_word_ids': word_id_data,
        'input_mask': mask_data,
        'input_type_ids': type_id_data,
    }
    outputs = network(inputs)
    self.assertEqual(outputs['sequence_output'].shape,
                     (batch_size, sequence_length, hidden_size))


if __name__ == '__main__':
  tf.test.main()
official/projects/longformer/longformer_experiments.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer experiments."""
# pylint: disable=g-doc-return-or-yield,line-too-long
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.projects.longformer.longformer import LongformerEncoderConfig

AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig


@dataclasses.dataclass
class LongformerOptimizationConfig(optimization.OptimizationConfig):
  """Longformer optimization configuration."""
  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
      type='adamw',
      adamw=AdamWeightDecay(
          weight_decay_rate=0.01,
          exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
          epsilon=1e-6))
  learning_rate: optimization.LrConfig = optimization.LrConfig(
      type='polynomial',
      polynomial=PolynomialLr(
          initial_learning_rate=1e-4,
          decay_steps=1000000,
          end_learning_rate=0.0))
  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
      type='polynomial', polynomial=PolynomialWarmupConfig(warmup_steps=10000))


@exp_factory.register_config_factory('longformer/pretraining')
def longformer_pretraining() -> cfg.ExperimentConfig:
  """Longformer pretraining experiment."""
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=masked_lm.MaskedLMConfig(
          model=bert.PretrainerConfig(
              encoder=encoders.EncoderConfig(
                  type='any', any=LongformerEncoderConfig()),
              cls_heads=[
                  bert.ClsHeadConfig(
                      inner_dim=768,
                      num_classes=2,
                      dropout_rate=0.1,
                      name='next_sentence')
              ]),
          train_data=pretrain_dataloader.BertPretrainDataConfig(
              use_v2_feature_names=True),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              use_v2_feature_names=True, is_training=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=LongformerOptimizationConfig(),
          train_steps=1000000),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('longformer/glue')
def longformer_glue() -> cfg.ExperimentConfig:
  """Longformer glue fine-tuning."""
  config = cfg.ExperimentConfig(
      task=sentence_prediction.SentencePredictionConfig(
          model=sentence_prediction.ModelConfig(
              encoder=encoders.EncoderConfig(
                  type='any', any=LongformerEncoderConfig())),
          train_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(),
          validation_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate':
                          0.01,
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 3e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
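Because both factories register with `exp_factory`, downstream code can materialize and override a config programmatically. A small sketch (field names follow `LongformerEncoderConfig` as used elsewhere in this commit; the override values are illustrative):

# Sketch: fetch the registered GLUE experiment and override encoder fields.
# Assumes this module has been imported so the factories are registered.
from official.core import exp_factory
from official.projects.longformer import longformer_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('longformer/glue')
config.task.model.encoder.any.attention_window = [32] * 12
config.task.model.encoder.any.global_attention_size = 1
config.task.train_data.global_batch_size = 32
print(config.trainer.optimizer_config.optimizer.type)  # adamw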
official/vision/beta/train.py → official/projects/longformer/train.py
...
...
@@ -12,22 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-"""TensorFlow Model Garden Vision training driver."""
+"""A customized training library for the specific task."""

 from absl import app
 from absl import flags
 import gin

-# pylint: disable=unused-import
-from official.common import registry_imports
-# pylint: enable=unused-import
 from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
+from official.projects.longformer import longformer_experiments  # pylint: disable=unused-import

 FLAGS = flags.FLAGS
...
...
@@ -51,7 +48,9 @@ def main(_):
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
       num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
+      tpu_address=params.runtime.tpu,
+      **params.runtime.model_parallelism())
   with distribution_strategy.scope():
     task = task_factory.get_task(params.task, logging_dir=model_dir)
...
...
@@ -64,7 +63,7 @@ def main(_):
   train_utils.save_gin_config(FLAGS.mode, model_dir)


 if __name__ == '__main__':
   tfm_flags.define_flags()
   flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)
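With the Longformer experiments imported, the driver launches like any other Model Garden trainer. The three flags marked required above are the minimum; a typical invocation (model directory path is a placeholder) would be:

python3 -m official.projects.longformer.train \
  --experiment=longformer/glue \
  --mode=train_and_eval \
  --model_dir=/tmp/longformer_glue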
official/projects/longformer/utils/convert_pretrained_pytorch_checkpoint_to_tf.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts pre-trained pytorch checkpoint into a tf encoder checkpoint."""
import os

from absl import app
import numpy as np
import tensorflow as tf
import transformers

from official.modeling import tf_utils
from official.projects.longformer.longformer import LongformerEncoderConfig
from official.projects.longformer.longformer_encoder import LongformerEncoder


def _get_pytorch_longformer_model():
  pretrained_lm = "allenai/longformer-base-4096"
  model = transformers.AutoModel.from_pretrained(pretrained_lm)
  return {n: p.data.numpy() for n, p in model.named_parameters()}
def _create_longformer_model():
  """Creates a Longformer model."""
  encoder_cfg = LongformerEncoderConfig
  encoder_cfg.vocab_size = 50265
  encoder_cfg.max_position_embeddings = 4098
  encoder_cfg.attention_window = [2] * encoder_cfg.num_layers
  encoder_cfg.global_attention_size = 1
  encoder = LongformerEncoder(
      attention_window=encoder_cfg.attention_window,
      global_attention_size=encoder_cfg.global_attention_size,
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      inner_dim=encoder_cfg.intermediate_size,
      inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      output_dropout=encoder_cfg.dropout_rate,
      attention_dropout=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      norm_first=encoder_cfg.norm_first)
  return encoder
# pylint: disable=protected-access
def convert(encoder, allenai_model):
  """Converts AllenAI Longformer weights to the encoder in this codebase."""
  num_layers = encoder._config["num_layers"]
  num_attention_heads = encoder._config["num_attention_heads"]
  hidden_size = encoder._config["hidden_size"]
  head_size = hidden_size // num_attention_heads
  assert head_size * num_attention_heads == hidden_size
  encoder._embedding_layer.set_weights(
      [allenai_model["embeddings.word_embeddings.weight"]])
  encoder._embedding_norm_layer.set_weights([
      allenai_model["embeddings.LayerNorm.weight"],
      allenai_model["embeddings.LayerNorm.bias"]
  ])
  encoder._type_embedding_layer.set_weights([
      np.repeat(
          allenai_model["embeddings.token_type_embeddings.weight"], 2, axis=0)
  ])
  encoder._position_embedding_layer.set_weights(
      [allenai_model["embeddings.position_embeddings.weight"]])
  encoder._pooler_layer.set_weights([
      allenai_model["pooler.dense.weight"],
      allenai_model["pooler.dense.bias"]
  ])
  for layer_num in range(num_layers):
    encoder._transformer_layers[
        layer_num]._attention_layer._global_key_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.key_global.weight"]
            .T.reshape((hidden_size, num_attention_heads, head_size)),
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.key_global.bias"]
            .reshape((num_attention_heads, head_size))
        ])
    encoder._transformer_layers[
        layer_num]._attention_layer._global_query_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.query_global.weight"]
            .T.reshape((hidden_size, num_attention_heads, head_size)),
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.query_global.bias"]
            .reshape((num_attention_heads, head_size))
        ])
    encoder._transformer_layers[
        layer_num]._attention_layer._global_value_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.value_global.weight"]
            .T.reshape((hidden_size, num_attention_heads, head_size)),
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.value_global.bias"]
            .reshape((num_attention_heads, head_size))
        ])
    # Local key uses the (non-global) `key` weight and bias.
    encoder._transformer_layers[
        layer_num]._attention_layer._key_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.key.weight"]
            .T.reshape((hidden_size, num_attention_heads, head_size)),
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.key.bias"]
            .reshape((num_attention_heads, head_size))
        ])
    encoder._transformer_layers[
        layer_num]._attention_layer._query_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.query.weight"]
            .T.reshape((hidden_size, num_attention_heads, head_size)),
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.query.bias"]
            .reshape((num_attention_heads, head_size))
        ])
    encoder._transformer_layers[
        layer_num]._attention_layer._value_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.value.weight"]
            .T.reshape((hidden_size, num_attention_heads, head_size)),
            allenai_model[
                f"encoder.layer.{layer_num}.attention.self.value.bias"]
            .reshape((num_attention_heads, head_size))
        ])
    encoder._transformer_layers[
        layer_num]._attention_layer._output_dense.set_weights([
            allenai_model[
                f"encoder.layer.{layer_num}.attention.output.dense.weight"].T,
            allenai_model[
                f"encoder.layer.{layer_num}.attention.output.dense.bias"]
        ])
    encoder._transformer_layers[layer_num]._attention_layer_norm.set_weights([
        allenai_model[
            f"encoder.layer.{layer_num}.attention.output.LayerNorm.weight"],
        allenai_model[
            f"encoder.layer.{layer_num}.attention.output.LayerNorm.bias"]
    ])
    encoder._transformer_layers[layer_num]._intermediate_dense.set_weights([
        allenai_model[
            f"encoder.layer.{layer_num}.intermediate.dense.weight"].T,
        allenai_model[f"encoder.layer.{layer_num}.intermediate.dense.bias"]
    ])
    encoder._transformer_layers[layer_num]._output_dense.set_weights([
        allenai_model[f"encoder.layer.{layer_num}.output.dense.weight"].T,
        allenai_model[f"encoder.layer.{layer_num}.output.dense.bias"]
    ])
    encoder._transformer_layers[layer_num]._output_layer_norm.set_weights([
        allenai_model[f"encoder.layer.{layer_num}.output.LayerNorm.weight"],
        allenai_model[f"encoder.layer.{layer_num}.output.LayerNorm.bias"]
    ])
def convert_checkpoint(output_path):
  """Converts and saves the checkpoint."""
  output_dir, _ = os.path.split(output_path)
  tf.io.gfile.makedirs(output_dir)
  encoder = _create_longformer_model()
  allenai_model = _get_pytorch_longformer_model()
  sequence_length = 128
  batch_size = 2
  word_id_data = np.random.randint(
      10, size=(batch_size, sequence_length), dtype=np.int32)
  mask_data = np.random.randint(
      2, size=(batch_size, sequence_length), dtype=np.int32)
  type_id_data = np.random.randint(
      2, size=(batch_size, sequence_length), dtype=np.int32)
  inputs = {
      "input_word_ids": word_id_data,
      "input_mask": mask_data,
      "input_type_ids": type_id_data,
  }
  # Run a forward pass once so all variables are built before assignment.
  encoder(inputs)
  convert(encoder, allenai_model)
  tf.train.Checkpoint(encoder=encoder).write(output_path)


def main(_):
  convert_checkpoint("longformer-4096/longformer")


if __name__ == "__main__":
  app.run(main)
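Restoring the converted checkpoint mirrors the write path: the object graph must match, i.e. the encoder is wrapped in a `tf.train.Checkpoint` under the `encoder` key, and the model must be built before restoring. A sketch (the dummy-batch shapes are illustrative):

# Sketch: load the converted TF checkpoint back into a fresh encoder.
import tensorflow as tf

encoder = _create_longformer_model()
encoder({  # build variables with a dummy batch before restoring
    "input_word_ids": tf.zeros((1, 128), tf.int32),
    "input_mask": tf.ones((1, 128), tf.int32),
    "input_type_ids": tf.zeros((1, 128), tf.int32),
})
ckpt = tf.train.Checkpoint(encoder=encoder)
ckpt.restore("longformer-4096/longformer").assert_consumed()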
official/projects/longformer/utils/longformer_tokenizer_to_tfrecord.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Longformer training examples to Tfrecord."""
import collections
import os

import datasets
import tensorflow as tf
import transformers

pretrained_lm = "allenai/longformer-base-4096"
task_name = "mnli"
save_path = "./"

raw_datasets = datasets.load_dataset("glue", task_name, cache_dir=None)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_lm,
    use_fast=True,
)

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[task_name]

padding = "max_length"
# Make sure this is the same as the model input size.
max_seq_length = 512


def preprocess_function(examples):
  # Tokenize the texts.
  args = ((examples[sentence1_key],) if sentence2_key is None else
          (examples[sentence1_key], examples[sentence2_key]))
  result = tokenizer(
      *args, padding=padding, max_length=max_seq_length, truncation=True)
  return result


raw_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    desc="Running tokenizer on dataset",
)

train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets[
    "validation_matched" if task_name == "mnli" else "validation"]

print("train_dataset", train_dataset[0])
print("eval_dataset", eval_dataset[0])
def file_based_convert_examples_to_features(examples, output_file):
  """Converts a set of `InputExample`s to a TFRecord file."""
  tf.io.gfile.makedirs(os.path.dirname(output_file))
  writer = tf.io.TFRecordWriter(output_file)

  for ex_index, example in enumerate(examples):
    if ex_index % 10000 == 0:
      print(f"Writing example {ex_index} of {len(examples)}")

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(example["input_ids"])
    features["input_mask"] = create_int_feature(example["attention_mask"])
    features["segment_ids"] = create_int_feature(
        [0] * len(example["attention_mask"]))
    features["label_ids"] = create_int_feature([example["label"]])
    features["is_real_example"] = create_int_feature([1])
    features["example_id"] = create_int_feature([example["idx"]])
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


file_based_convert_examples_to_features(
    train_dataset,
    os.path.join(save_path,
                 f"{pretrained_lm.replace('/', '_')}_train.tf_record"))
file_based_convert_examples_to_features(
    eval_dataset,
    os.path.join(save_path,
                 f"{pretrained_lm.replace('/', '_')}_eval.tf_record"))
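The records written above can be read back with a feature spec that matches the fixed `max_seq_length` of 512 and the feature names used by the writer. A minimal sketch (the file name is the one produced by the calls above for the default `pretrained_lm`):

# Sketch: parse one example back out of the generated TFRecord.
import tensorflow as tf

name_to_features = {
    "input_ids": tf.io.FixedLenFeature([512], tf.int64),
    "input_mask": tf.io.FixedLenFeature([512], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([512], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
}
ds = tf.data.TFRecordDataset("allenai_longformer-base-4096_train.tf_record")
parsed = ds.map(lambda rec: tf.io.parse_single_example(rec, name_to_features))
for ex in parsed.take(1):
  print(ex["input_ids"].shape, ex["label_ids"].numpy())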
official/projects/movinet/README.md
...
...
@@ -176,8 +176,7 @@ devices. See the [TF Lite Example](#tf-lite-example) to export and run your own
 models. We also provide
 [quantized TF Lite binaries via TF Hub](https://tfhub.dev/s?deployment-format=lite&q=movinet).

 For reference, MoViNet-A0-Stream runs with a similar latency to
-[MobileNetV3-Large]
-(https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
+[MobileNetV3-Large](https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
 with +5% accuracy on Kinetics 600.

 | Model Name | Input Shape | Pixel 4 Latency \* | x86 Latency \* | TF Lite Binary |
...
...
official/projects/movinet/modeling/movinet.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Contains definitions of Mobile Video Networks.

 Reference: https://arxiv.org/pdf/2103.11511.pdf
...
...
official/projects/movinet/modeling/movinet_layers.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Contains common building blocks for MoViNets.

 Reference: https://arxiv.org/pdf/2103.11511.pdf
...
...
official/projects/movinet/modeling/movinet_layers_test.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Tests for movinet_layers.py."""

 from absl.testing import parameterized
...
...
official/projects/movinet/modeling/movinet_model_test.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Tests for movinet_model.py."""

 from absl.testing import parameterized
...
...
official/projects/movinet/modeling/movinet_test.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Tests for movinet.py."""

 from absl.testing import parameterized
...
...
official/projects/movinet/tools/export_saved_model.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 r"""Exports models to tf.saved_model.

 Export example:
...
...