ModelZoo / ResNet50_tensorflow / Commits

Unverified commit 8b641b13, authored Mar 26, 2022 by Srihari Humbarwadi; committed by GitHub on Mar 26, 2022.

Merge branch 'tensorflow:master' into panoptic-deeplab

Parents: 7cffacfe, 357fa547
Showing 20 changed files (the merge touches 411+ files in total), with 2,870 additions and 15 deletions across the files listed below.
official/projects/longformer/experiments/glue_mnli.yaml (+47, -0)
official/projects/longformer/experiments/glue_mnli_allenai.yaml (+48, -0)
official/projects/longformer/experiments/pretraining_512.yaml (+74, -0)
official/projects/longformer/longformer.py (+69, -0)
official/projects/longformer/longformer_attention.py (+1082, -0)
official/projects/longformer/longformer_attention_test.py (+306, -0)
official/projects/longformer/longformer_encoder.py (+365, -0)
official/projects/longformer/longformer_encoder_block.py (+340, -0)
official/projects/longformer/longformer_encoder_test.py (+97, -0)
official/projects/longformer/longformer_experiments.py (+123, -0)
official/projects/longformer/train.py (+6, -7)
official/projects/longformer/utils/convert_pretrained_pytorch_checkpoint_to_tf.py (+200, -0)
official/projects/longformer/utils/longformer_tokenizer_to_tfrecord.py (+112, -0)
official/projects/movinet/README.md (+1, -2)
official/projects/movinet/modeling/movinet.py (+0, -1)
official/projects/movinet/modeling/movinet_layers.py (+0, -1)
official/projects/movinet/modeling/movinet_layers_test.py (+0, -1)
official/projects/movinet/modeling/movinet_model_test.py (+0, -1)
official/projects/movinet/modeling/movinet_test.py (+0, -1)
official/projects/movinet/tools/export_saved_model.py (+0, -1)
official/projects/longformer/experiments/glue_mnli.yaml (new file, mode 100644)

task:
  hub_module_url: ''
  model:
    num_classes: 3
    encoder:
      type: any
      any:
        max_position_embeddings: 512
        attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
        global_attention_size: 1
  metric_type: 'accuracy'
  train_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: true
    seq_length: 128
  validation_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: false
    seq_length: 128
trainer:
  checkpoint_interval: 1000
  continuous_eval_timeout: 7200
  optimizer_config:
    learning_rate:
      polynomial:
        decay_steps: 61359
        end_learning_rate: 0.0
        initial_learning_rate: 3.0e-05
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 6136
      type: polynomial
  steps_per_loop: 100
  summary_interval: 100
  # Training data size 392,702 examples, 5 epochs.
  train_steps: 61359
  validation_interval: 2000
  validation_steps: 307
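
A config like this is consumed by the project's train.py driver (also modified in this commit). A hedged launch sketch, assuming the standard Model Garden flags (--experiment, --config_file, --params_override, --mode, --model_dir); the experiment name below is a placeholder, since the registered names live in longformer_experiments.py, and the TODO input paths must be filled in or overridden:

  python3 official/projects/longformer/train.py \
    --experiment=<longformer_glue_experiment_name> \
    --config_file=official/projects/longformer/experiments/glue_mnli.yaml \
    --params_override="task.train_data.input_path=/path/to/mnli_train.tfrecord" \
    --mode=train_and_eval \
    --model_dir=/tmp/longformer_mnli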

official/projects/longformer/experiments/glue_mnli_allenai.yaml (new file, mode 100644)

task:
  hub_module_url: ''
  model:
    num_classes: 3
    encoder:
      type: any
      any:
        max_position_embeddings: 4098
        attention_window: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
        global_attention_size: 1
        vocab_size: 50265
  metric_type: 'accuracy'
  train_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: true
    seq_length: 512
  validation_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: TODO
    is_training: false
    seq_length: 512
trainer:
  checkpoint_interval: 1000
  continuous_eval_timeout: 7200
  optimizer_config:
    learning_rate:
      polynomial:
        decay_steps: 61359
        end_learning_rate: 0.0
        initial_learning_rate: 3.0e-05
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 6136
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  # Training data size 392,702 examples, 5 epochs.
  train_steps: 61359
  validation_interval: 2000
  validation_steps: 307
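
The trainer comment pins down the schedule arithmetic. A minimal sketch of the derivation in Python (illustrative; the MNLI train and matched-dev sizes are the published dataset counts, not values asserted by this diff):

  train_steps = 392_702 * 5 // 32      # 5 epochs at global batch 32 -> 61359, as configured
  warmup_steps = 6136                  # ~10% of train_steps (61359 * 0.1 ~= 6135.9)
  validation_steps = -(-9_815 // 32)   # ceil(9815 / 32) = 307, as configured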

official/projects/longformer/experiments/pretraining_512.yaml (new file, mode 100644)

task:
  init_checkpoint: ""
  model:
    cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
    encoder:
      type: any
      any:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        embedding_size: 768
        hidden_activation: gelu
        hidden_size: 768
        initializer_range: 0.02
        intermediate_size: 3072
        max_position_embeddings: 512
        num_attention_heads: 12
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 30522
        attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
        global_attention_size: 1
  train_data:
    drop_remainder: true
    global_batch_size: 256
    # input_path is a single comma-separated string enumerating all 500
    # Wikipedia pretraining shards, wikipedia.tfrecord-00000-of-00500 through
    # wikipedia.tfrecord-00499-of-00500:
    input_path: gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00000-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00001-of-00500,...,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00499-of-00500
    is_training: true
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
  validation_data:
    drop_remainder: true
    global_batch_size: 256
    input_path: TODO
    is_training: false
    max_predictions_per_seq: 76
    seq_length: 512
    use_next_sentence_label: true
    use_position_id: false
trainer:
  checkpoint_interval: 20000
  max_to_keep: 5
  optimizer_config:
    learning_rate:
      polynomial:
        cycle: false
        decay_steps: 1000000
        end_learning_rate: 0.0
        initial_learning_rate: 0.0001
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 10000
      type: polynomial
  steps_per_loop: 50
  summary_interval: 50
  train_steps: 1000000
  validation_interval: 1000
  validation_steps: 64
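
The trainer block encodes linear warmup into a linear decay (polynomial with power 1.0, cycle: false). A minimal sketch of the implied learning-rate curve, assuming the warmup ramp and the decay compose in the usual Model Garden way (an illustration, not the library implementation):

  def pretrain_lr(step,
                  initial_lr=1e-4,
                  end_lr=0.0,
                  warmup_steps=10_000,
                  decay_steps=1_000_000,
                  power=1.0):
    """Learning rate at `step` for the schedule configured above."""
    # Polynomial decay with cycle=false clamps at decay_steps.
    frac = min(step, decay_steps) / decay_steps
    decayed = (initial_lr - end_lr) * (1.0 - frac) ** power + end_lr
    if step < warmup_steps:
      # Power-1 polynomial warmup: a linear ramp from 0 toward the decayed rate.
      return decayed * (step / warmup_steps)
    return decayed

  # pretrain_lr(10_000) ~= 9.9e-5; pretrain_lr(1_000_000) == 0.0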

official/projects/longformer/longformer.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer model configurations and instantiation methods."""
import dataclasses
from typing import List

import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.projects.longformer.longformer_encoder import LongformerEncoder


@dataclasses.dataclass
class LongformerEncoderConfig(encoders.BertEncoderConfig):
  """Extra parameters for Longformer configs.

  Attributes:
    attention_window: list of ints representing the window size for each layer.
    global_attention_size: the size of global attention used for each token.
    pad_token_id: the token id for the pad token.
  """
  attention_window: List[int] = dataclasses.field(default_factory=list)
  global_attention_size: int = 0
  pad_token_id: int = 1


@base_config.bind(LongformerEncoderConfig)
def get_encoder(encoder_cfg: LongformerEncoderConfig):
  """Gets a 'LongformerEncoder' object.

  Args:
    encoder_cfg: A 'LongformerEncoderConfig'.

  Returns:
    An encoder object.
  """
  encoder = LongformerEncoder(
      attention_window=encoder_cfg.attention_window,
      global_attention_size=encoder_cfg.global_attention_size,
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      inner_dim=encoder_cfg.intermediate_size,
      inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      output_dropout=encoder_cfg.dropout_rate,
      attention_dropout=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      norm_first=encoder_cfg.norm_first)
  return encoder
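
A minimal usage sketch (the hyperparameter values mirror pretraining_512.yaml rather than anything asserted by this file, and it assumes the @base_config.bind decorator leaves get_encoder directly callable):

  from official.projects.longformer import longformer

  config = longformer.LongformerEncoderConfig(
      vocab_size=30522,
      num_layers=12,
      attention_window=[32] * 12,  # one window size per layer
      global_attention_size=1,
  )
  encoder = longformer.get_encoder(config)  # a LongformerEncoder Keras model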
official/projects/longformer/longformer_attention.py
0 → 100644
View file @
8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention block. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
import math
import string

import numpy as np
import tensorflow as tf

from official.modeling.tf_utils import get_shape_list

_CHR_IDX = string.ascii_lowercase


def _build_attention_equation(rank, attn_axes):
  """Builds einsum equations for the attention computation.

  Query, key, value inputs after projection are expected to have the shape as:
  `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
  `bs` and `<non-attention dims>` are treated as `<batch dims>`.

  The attention operations can be generalized:
  (1) Query-key dot product:
  `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
  <key attention dims>, num_heads, channels) -> (<batch dims>, num_heads,
  <query attention dims>, <key attention dims>)`
  (2) Combination:
  `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
  (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
  <query attention dims>, num_heads, channels)`

  Args:
    rank: Rank of query, key, value tensors.
    attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be applied
      to.

  Returns:
    Einsum equations.
  """
  target_notation = _CHR_IDX[:rank]
  # `batch_dims` includes the head dim.
  batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
  letter_offset = rank
  source_notation = ""
  for i in range(rank):
    if i in batch_dims or i == rank - 1:
      source_notation += target_notation[i]
    else:
      source_notation += _CHR_IDX[letter_offset]
      letter_offset += 1

  product_notation = "".join([target_notation[i] for i in batch_dims] +
                             [target_notation[i] for i in attn_axes] +
                             [source_notation[i] for i in attn_axes])
  dot_product_equation = f"{source_notation},{target_notation}->{product_notation}"
  attn_scores_rank = len(product_notation)
  combine_equation = f"{product_notation},{source_notation}->{target_notation}"
  return dot_product_equation, combine_equation, attn_scores_rank
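
# A quick sanity check of `_build_attention_equation` (illustrative only,
# assuming the usual rank-4 (batch, seq, num_heads, channels) projection
# layout with attention applied over axis 1):
#
#   dot_eq, combine_eq, rank = _build_attention_equation(4, attn_axes=(1,))
#   dot_eq     == "aecd,abcd->acbe"   # (B,S,N,H),(B,T,N,H) -> (B,N,T,S)
#   combine_eq == "acbe,aecd->abcd"   # (B,N,T,S),(B,S,N,H) -> (B,T,N,H)
#   rank       == 4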


def _build_proj_equation(free_dims, bound_dims, output_dims):
  """Builds an einsum equation for projections inside multi-head attention."""
  input_str = ""
  kernel_str = ""
  output_str = ""
  bias_axes = ""
  letter_offset = 0
  for i in range(free_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    output_str += char

  letter_offset += free_dims
  for i in range(bound_dims):
    char = _CHR_IDX[i + letter_offset]
    input_str += char
    kernel_str += char

  letter_offset += bound_dims
  for i in range(output_dims):
    char = _CHR_IDX[i + letter_offset]
    kernel_str += char
    output_str += char
    bias_axes += char
  equation = f"{input_str},{kernel_str}->{output_str}"

  return equation, bias_axes, len(output_str)


def _get_output_shape(output_rank, known_last_dims):
  return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)


@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerAttention(tf.keras.layers.MultiHeadAttention):
  """LongformerAttention.

  Args:
    attention_window: int representing the window size for attention.
    layer_id: int of the id of the layer.
    global_attention_size: the size of global attention used for each token.
  """

  def __init__(self, attention_window, layer_id, global_attention_size,
               **kwargs):
    super().__init__(**kwargs)
    self._layer_id = layer_id
    self._attention_window = attention_window
    assert (self._attention_window % 2 == 0), (
        f"`attention_window` for layer {self._layer_id} has to be an even "
        f"value. Given {self._attention_window}")
    assert (self._attention_window > 0), (
        f"`attention_window` for layer {self._layer_id} has to be positive. "
        f"Given {self._attention_window}")
    self._one_sided_attn_window_size = self._attention_window // 2
    self.global_attention_size = global_attention_size

  def _build_from_signature(self, query, value, key=None):
    """Builds layers and variables.

    Once the method is called, self._built_from_signature will be set to True.

    Args:
      query: Query tensor or TensorShape.
      value: Value tensor or TensorShape.
      key: Key tensor or TensorShape.
    """
    self._built_from_signature = True
    if hasattr(query, "shape"):
      self._query_shape = tf.TensorShape(query.shape)
    else:
      self._query_shape = tf.TensorShape(query)
    if hasattr(value, "shape"):
      self._value_shape = tf.TensorShape(value.shape)
    else:
      self._value_shape = tf.TensorShape(value)
    if key is None:
      self._key_shape = self._value_shape
    elif hasattr(key, "shape"):
      self._key_shape = tf.TensorShape(key.shape)
    else:
      self._key_shape = tf.TensorShape(key)

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Any setup work performed only once should happen in an `init_scope`
    # to avoid creating symbolic Tensors that will later pollute any eager
    # operations.
    # with tf_utils.maybe_init_scope(self):
    # TODO(crickwu): check whether tf_utils.maybe_init_scope(self) (keras)
    # is needed.
    free_dims = self._query_shape.rank - 1
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=1, output_dims=2)
    self._query_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="query",
        **common_kwargs)
    self._global_query_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="global_query",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        self._key_shape.rank - 1, bound_dims=1, output_dims=2)
    self._key_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="key",
        **common_kwargs)
    self._global_key_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._key_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="global_key",
        **common_kwargs)
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        self._value_shape.rank - 1, bound_dims=1, output_dims=2)
    self._value_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._value_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="value",
        **common_kwargs)
    self._global_value_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1,
                                       [self._num_heads, self._value_dim]),
        bias_axes=bias_axes if self._use_bias else None,
        name="global_value",
        **common_kwargs)

    # Builds the attention computations for multi-head dot product attention.
    # These computations could be wrapped into the keras attention layer once
    # it supports multi-head einsum computations.
    self._build_attention(output_rank)
    self._global_dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
    # self._output_dense = self._make_output_dense(
    #     free_dims, common_kwargs, "attention_output")
    self._output_dense = tf.keras.layers.Dense(
        units=self._num_heads * self._key_dim, name="dense", **common_kwargs)

  def call(self,
           hidden_states,
           attention_mask=None,
           is_index_masked=None,
           is_index_global_attn=None,
           training=None):
    """Applies dot-product attention with query, key, value tensors.

    This function defines the computation inside `call` with projected
    multi-head Q, K, V inputs. Users can override this function for customized
    attention implementation.

    Args:
      hidden_states: inputs for generating query, key and value tensors.
      attention_mask: an additive mask with zeros for positions to attend to
        and large negative values for masked positions.
      is_index_masked: boolean indicating whether the index is masked.
      is_index_global_attn: boolean indicating whether the index is global
        attention.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      attention_output: Multi-headed outputs of attention computation.
    """
    if not self._built_from_signature:
      self._build_from_signature(
          query=hidden_states, value=hidden_states, key=hidden_states)

    # N = `num_attention_heads`
    # H = `size_per_head`
    # `query` = [B, T, N, H]
    query = self._query_dense(hidden_states)
    # `key` = [B, S, N, H]
    key = self._key_dense(hidden_states)
    # `value` = [B, S, N, H]
    value = self._value_dense(hidden_states)

    # Note: Applying scalar multiply at the smaller end of einsum improves
    # XLA performance, but may introduce slight numeric differences in
    # the Transformer attention head.
    query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

    batch_size, seq_len, num_heads, head_dim = get_shape_list(query)

    # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
    attn_scores = self._sliding_chunks_query_key_matmul(
        query, key, self._one_sided_attn_window_size)

    # diagonal mask with zeros everywhere and -inf in place of padding
    diagonal_mask = self._sliding_chunks_query_key_matmul(
        tf.ones(get_shape_list(attention_mask)),
        attention_mask,
        self._one_sided_attn_window_size,
    )

    # pad local attention probs
    attn_scores += diagonal_mask

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(attn_scores),
          [
              batch_size, seq_len, self._num_heads,
              self._one_sided_attn_window_size * 2 + 1
          ],
          message=f"attn_probs should be of size "
          f"({batch_size}, {seq_len}, {num_heads}, "
          f"{self._one_sided_attn_window_size * 2 + 1}),"
          f" but is of size {get_shape_list(attn_scores)}",
      )

    # compute global attn indices required throughout the forward fn
    (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn,
                                      self.global_attention_size)
    # this function is only relevant for global attention
    if self.global_attention_size > 0:
      attn_scores = self._concat_with_global_key_attn_probs(
          attn_scores=attn_scores,
          query_vectors=query,
          key_vectors=key,
          max_num_global_attn_indices=max_num_global_attn_indices,
          is_index_global_attn_nonzero=is_index_global_attn_nonzero,
          is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
          is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
      )

    attn_probs = tf.nn.softmax(attn_scores, axis=-1)

    # softmax sometimes inserts NaN if all positions are masked,
    # replace them with 0.
    # Make sure to create a mask with the proper shape:
    # if is_global_attn==True => [batch_size, seq_len, self.num_heads,
    #   self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
    # if is_global_attn==False => [batch_size, seq_len, self.num_heads,
    #   self.one_sided_attn_window_size * 2 + 1]
    if self.global_attention_size > 0:
      masked_index = tf.tile(
          is_index_masked[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
           max_num_global_attn_indices + 1),
      )
    else:
      masked_index = tf.tile(
          is_index_masked[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
      )
    attn_probs = tf.where(
        masked_index,
        tf.zeros(get_shape_list(masked_index), dtype=attn_probs.dtype),
        attn_probs,
    )

    # Head masking is not supported in this port; the block below is kept for
    # parity with the reference HuggingFace implementation.
    layer_head_mask = None
    if layer_head_mask is not None:
      if tf.executing_eagerly():
        tf.debugging.assert_equal(
            get_shape_list(layer_head_mask), [self._num_heads],
            message=f"Head mask for a single layer should be of size "
            f"{(self._num_heads)}, but is "
            f"{get_shape_list(layer_head_mask)}",
        )
      attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs

    # apply dropout
    attn_probs = self._dropout_layer(attn_probs, training=training)
    value_vectors = tf.reshape(
        value, (batch_size, seq_len, self._num_heads, self._key_dim))

    # if global attention, compute sum of global and local attn
    if self.global_attention_size > 0:
      attn_output = self._compute_attn_output_with_global_indices(
          value_vectors=value_vectors,
          attn_probs=attn_probs,
          max_num_global_attn_indices=max_num_global_attn_indices,
          is_index_global_attn_nonzero=is_index_global_attn_nonzero,
          is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
      )
    else:
      attn_output = self._sliding_chunks_matmul_attn_probs_value(
          attn_probs, value_vectors, self._one_sided_attn_window_size)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(attn_output),
          [batch_size, seq_len, self._num_heads, head_dim],
          message="Unexpected size",
      )

    attn_output = tf.reshape(
        attn_output, (batch_size, seq_len, self._num_heads * self._key_dim))

    # compute value for global attention and overwrite to attention output
    # TODO(crickwu): remove the redundant computation
    if self.global_attention_size > 0:
      attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(  # pylint: disable=unused-variable
          attn_output=attn_output,
          hidden_states=hidden_states,
          max_num_global_attn_indices=max_num_global_attn_indices,
          layer_head_mask=layer_head_mask,
          is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
          is_index_global_attn_nonzero=is_index_global_attn_nonzero,
          is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
          is_index_masked=is_index_masked,
          training=training,
      )
    else:
      global_attn_probs = tf.zeros(
          (batch_size, self._num_heads, max_num_global_attn_indices, seq_len))

    # make sure that local attention probabilities are set to 0 for indices of
    # global attn
    if self.global_attention_size > 0:
      masked_global_attn_index = tf.tile(
          is_index_global_attn[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
           max_num_global_attn_indices + 1),
      )
    else:
      masked_global_attn_index = tf.tile(
          is_index_global_attn[:, :, None, None],
          (1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
      )
    attn_probs = tf.where(
        masked_global_attn_index,
        tf.zeros(
            get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
        attn_probs,
    )

    # extra information (attn_probs, global_attn_probs) could be returned here
    return attn_output

  def get_config(self):
    config = {
        "layer_id": self._layer_id,
        # Store the full (two-sided) window so that from_config round-trips.
        "attention_window": self._attention_window,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
    """Matrix multiplication of query and key tensors.

    This multiplication uses a sliding window attention pattern.
    This implementation splits the input into overlapping chunks of size
    2w (e.g. 512 for pretrained Longformer) with an overlap of size
    window_overlap.

    Args:
      query: query tensor.
      key: key tensor.
      window_overlap: int.

    Returns:
      diagonal_attention_scores: tensor.
    """
    batch_size, seq_len, num_heads, head_dim = get_shape_list(query)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          seq_len % (window_overlap * 2),
          0,
          message=f"Sequence length should be multiple of {window_overlap * 2}. "
          f"Given {seq_len}",
      )
      tf.debugging.assert_equal(
          get_shape_list(query),
          get_shape_list(key),
          message=f"Shape of query and key should be equal, but got query: "
          f"{get_shape_list(query)} and key: {get_shape_list(key)}",
      )

    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one,
    # then chunk seq_len into chunks of size window_overlap * 2
    query = tf.reshape(
        tf.transpose(query, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )
    key = tf.reshape(
        tf.transpose(key, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim))
    chunked_query = self._chunk(query, window_overlap)
    chunked_key = self._chunk(key, window_overlap)

    # matrix multiplication
    # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
    chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
    chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query,
                                         chunked_key)  # multiply

    # convert diagonals into columns
    paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
    diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
        chunked_attention_scores, paddings)

    # allocate space for the overall attention matrix where the chunks are
    # combined. The last dimension has (window_overlap * 2 + 1) columns.
    # The first (window_overlap) columns are the window_overlap lower
    # triangles (attention from a word to window_overlap previous words).
    # The following column is the attention score from each word to itself,
    # followed by window_overlap columns for the upper triangle.
    # copy parts from diagonal_chunked_attention_scores into the combined
    # matrix of attentions - copying the main diagonal and the upper triangle
    # TODO(crickwu): This code is most likely not very efficient and should be
    # improved.
    diagonal_attn_scores_up_triang = tf.concat(
        [
            diagonal_chunked_attention_scores[:, :, :window_overlap,
                                              :window_overlap + 1],
            diagonal_chunked_attention_scores[:, -1:, window_overlap:,
                                              :window_overlap + 1],
        ],
        axis=1,
    )

    # - copying the lower triangle
    diagonal_attn_scores_low_triang = tf.concat(
        [
            tf.zeros(
                (batch_size * num_heads, 1, window_overlap, window_overlap),
                dtype=diagonal_chunked_attention_scores.dtype,
            ),
            diagonal_chunked_attention_scores[:, :, -(window_overlap + 1):-1,
                                              window_overlap + 1:],
        ],
        axis=1,
    )
    diagonal_attn_scores_first_chunk = tf.concat(
        [
            tf.roll(
                diagonal_chunked_attention_scores,
                shift=[1, window_overlap],
                axis=[2, 3],
            )[:, :, :window_overlap, :window_overlap],
            tf.zeros(
                (batch_size * num_heads, 1, window_overlap, window_overlap),
                dtype=diagonal_chunked_attention_scores.dtype,
            ),
        ],
        axis=1,
    )
    first_chunk_mask = (
        tf.tile(
            tf.range(chunks_count + 1)[None, :, None, None],
            (batch_size * num_heads, 1, window_overlap, window_overlap),
        ) < 1)
    diagonal_attn_scores_low_triang = tf.where(
        first_chunk_mask,
        diagonal_attn_scores_first_chunk,
        diagonal_attn_scores_low_triang,
    )

    # merging upper and lower triangle
    diagonal_attention_scores = tf.concat(
        [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang],
        axis=-1)

    # separate batch_size and num_heads dimensions again
    diagonal_attention_scores = tf.transpose(
        tf.reshape(
            diagonal_attention_scores,
            (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
        ),
        (0, 2, 1, 3),
    )

    diagonal_attention_scores = self._mask_invalid_locations(
        diagonal_attention_scores, window_overlap)
    return diagonal_attention_scores
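
  # Illustrative shape walkthrough for `_sliding_chunks_query_key_matmul`
  # (hypothetical numbers, consistent with the unit tests in this commit):
  # with batch_size=1, num_heads=2, seq_len=8, head_dim=4, window_overlap=2,
  #   query/key after merging batch and head dims -> (2, 8, 4)
  #   after `_chunk`                              -> (2, 3, 4, 4)
  #   "bcxd,bcyd->bcxy" chunked scores            -> (2, 3, 4, 4)
  #   returned diagonal scores                    -> (1, 8, 2, 5)  # 2w+1 = 5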

  @staticmethod
  def _mask_invalid_locations(input_tensor, window_overlap):
    # create correct upper triangle bool mask
    mask_2d_upper = tf.reverse(
        tf.linalg.band_part(
            tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
        axis=[0],
    )
    # pad to full matrix
    padding = tf.convert_to_tensor(
        [[0, get_shape_list(input_tensor)[1] - window_overlap],
         [0, get_shape_list(input_tensor)[3] - window_overlap - 1]])
    # create lower mask
    mask_2d = tf.pad(mask_2d_upper, padding)
    # combine with upper mask
    mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
    # broadcast to full matrix
    mask_4d = tf.tile(mask_2d[None, :, None, :],
                      (get_shape_list(input_tensor)[0], 1, 1, 1))
    # inf tensor used for masking
    inf_tensor = -float("inf") * tf.ones_like(input_tensor)
    # mask
    input_tensor = tf.where(
        tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)

    return input_tensor

  def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value,
                                              window_overlap):
    """Same as _sliding_chunks_query_key_matmul but for attn_probs and value."""
    batch_size, seq_len, num_heads, head_dim = get_shape_list(value)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          seq_len % (window_overlap * 2),
          0,
          message="Seq_len has to be multiple of 2 * window_overlap",
      )
      tf.debugging.assert_equal(
          get_shape_list(attn_probs)[:3],
          get_shape_list(value)[:3],
          message="value and attn_probs must have same dims (except head_dim)",
      )
      tf.debugging.assert_equal(
          get_shape_list(attn_probs)[3],
          2 * window_overlap + 1,
          message="attn_probs last dim has to be 2 * window_overlap + 1",
      )

    chunks_count = seq_len // window_overlap - 1

    # group batch_size and num_heads dimensions into one, then chunk seq_len
    # into chunks of size 2 window overlap
    chunked_attn_probs = tf.reshape(
        tf.transpose(attn_probs, (0, 2, 1, 3)),
        (
            batch_size * num_heads,
            seq_len // window_overlap,
            window_overlap,
            2 * window_overlap + 1,
        ),
    )

    # group batch_size and num_heads dimensions into one
    value = tf.reshape(
        tf.transpose(value, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )

    # pad seq_len with w at the beginning of the sequence and another window
    # overlap at the end
    paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap],
                                     [0, 0]])
    padded_value = tf.pad(value, paddings, constant_values=-1)

    # chunk padded_value into chunks of size 3 window overlap and an overlap
    # of size window overlap
    frame_size = 3 * window_overlap * head_dim
    frame_hop_size = (
        get_shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
    chunked_value = tf.signal.frame(
        tf.reshape(padded_value, (batch_size * num_heads, -1)),
        frame_size,
        frame_hop_size,
    )
    chunked_value = tf.reshape(
        chunked_value,
        (batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
         head_dim),
    )

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(chunked_value),
          [batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
           head_dim],
          message="Chunked value has the wrong shape",
      )

    chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
    context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
    context = tf.transpose(
        tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
        (0, 2, 1, 3),
    )

    return context

  @staticmethod
  def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
    """Pads rows and then flips rows and columns."""
    # padding value is not important because it will be overwritten
    hidden_states_padded = tf.pad(hidden_states_padded, paddings)
    batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
        hidden_states_padded)
    hidden_states_padded = tf.reshape(
        hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))

    return hidden_states_padded

  @staticmethod
  def _pad_and_diagonalize(chunked_hidden_states):
    """Shifts every row 1 step right, converting columns into diagonals.

    Example:
      chunked_hidden_states: [ 0.4983,  2.6918, -0.0071,  1.0492,
                              -1.8348,  0.7672,  0.2986,  0.0285,
                              -0.7584,  0.4206, -0.0405,  0.1599,
                               2.0514, -1.1600,  0.5372,  0.2629 ]
      window_overlap = num_rows = 4
      (pad & diagonalize) =>
      [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
        0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
        0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
        0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]

    Args:
      chunked_hidden_states: tensor.

    Returns:
      padded_hidden_states: tensor.
    """
    total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(
        chunked_hidden_states)
    paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0],
                                     [0, window_overlap + 1]])
    # last dim becomes hidden_dim + window_overlap + 1
    chunked_hidden_states = tf.pad(chunked_hidden_states, paddings)
    # flatten the last two dims
    chunked_hidden_states = tf.reshape(chunked_hidden_states,
                                       (total_num_heads, num_chunks, -1))
    # drop the trailing window_overlap elements so rows shift by one
    chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap]
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (total_num_heads, num_chunks, window_overlap,
         window_overlap + hidden_dim),
    )
    chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]

    return chunked_hidden_states

  @staticmethod
  def _chunk(hidden_states, window_overlap):
    """Converts into overlapping chunks. Chunk size = 2w, overlap size = w."""
    batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
    num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

    # define frame size and frame stride (similar to convolution)
    frame_hop_size = window_overlap * hidden_dim
    frame_size = 2 * frame_hop_size
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length * hidden_dim))

    # chunk with overlap
    chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
                                            frame_hop_size)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(chunked_hidden_states),
          [batch_size, num_output_chunks, frame_size],
          message=f"Make sure chunking is correctly applied. Chunked hidden "
          f"states should have output dimension "
          f"{[batch_size, num_output_chunks, frame_size]}, but got "
          f"{get_shape_list(chunked_hidden_states)}.",
      )

    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
    )

    return chunked_hidden_states
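
  # Illustrative numbers for `_chunk` (hypothetical, consistent with the
  # accompanying unit test): seq_length=8, hidden_dim=4, window_overlap=2
  # gives frame_hop_size = 2 * 4 = 8, frame_size = 16, and
  # num_output_chunks = 2 * (8 // 4) - 1 = 3, i.e. an output of shape
  # (batch_size, 3, 4, 4).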

  @staticmethod
  def _get_global_attn_indices(is_index_global_attn, global_attention_size):
    """Computes global attn indices required throughout the forward pass."""
    # All global attention sizes are fixed through global_attention_size.
    batch_size, _ = get_shape_list(is_index_global_attn)
    max_num_global_attn_indices = global_attention_size

    row_indices = tf.range(batch_size)
    row_indices = tf.repeat(
        tf.expand_dims(row_indices, axis=0),
        repeats=[global_attention_size],
        axis=0)
    row_indices = tf.reshape(row_indices,
                             (batch_size * global_attention_size, 1))

    col_indices = tf.range(global_attention_size)
    col_indices = tf.repeat(
        tf.expand_dims(col_indices, axis=1), repeats=[batch_size], axis=0)

    is_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
                                             axis=1)
    # this is actually the same as `is_index_global_attn_nonzero`,
    # since we assume all global attention blocks are the same size
    is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
                                                   axis=1)
    # empty tensor
    is_local_index_no_global_attn_nonzero = tf.reshape(
        tf.expand_dims(tf.range(0), axis=1), (0, 2))

    return (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    )

  def _concat_with_global_key_attn_probs(
      self,
      attn_scores,
      key_vectors,
      query_vectors,
      max_num_global_attn_indices,
      is_index_global_attn_nonzero,
      is_local_index_global_attn_nonzero,
      is_local_index_no_global_attn_nonzero,
  ):
    batch_size = get_shape_list(key_vectors)[0]

    # select global key vectors
    global_key_vectors = tf.gather_nd(key_vectors,
                                      is_index_global_attn_nonzero)
    # create only global key vectors
    key_vectors_only_global = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_key_vectors,
        shape=(
            batch_size,
            max_num_global_attn_indices,
            self._num_heads,
            self._key_dim,
        ),
    )
    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
                                           key_vectors_only_global)
    # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
    attn_probs_from_global_key_trans = tf.transpose(
        attn_probs_from_global_key, (0, 3, 1, 2))
    mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
                 ) + tuple(
                     get_shape_list(attn_probs_from_global_key_trans)[-2:])
    mask = tf.ones(mask_shape) * -10000.0
    mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

    # scatter mask
    attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
        attn_probs_from_global_key_trans,
        is_local_index_no_global_attn_nonzero,
        mask,
    )

    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.transpose(
        attn_probs_from_global_key_trans, (0, 2, 3, 1))

    # concat to attn_probs
    # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
    attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)

    return attn_scores

  def _compute_attn_output_with_global_indices(
      self,
      value_vectors,
      attn_probs,
      max_num_global_attn_indices,
      is_index_global_attn_nonzero,
      is_local_index_global_attn_nonzero,
  ):
    batch_size = get_shape_list(attn_probs)[0]

    # cut local attn probs to global only
    attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
    # select global value vectors
    global_value_vectors = tf.gather_nd(value_vectors,
                                        is_index_global_attn_nonzero)
    # create only global value vectors
    value_vectors_only_global = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_value_vectors,
        shape=(
            batch_size,
            max_num_global_attn_indices,
            self._num_heads,
            self._key_dim,
        ),
    )

    # compute attn output only global
    attn_output_only_global = tf.einsum("blhs,bshd->blhd",
                                        attn_probs_only_global,
                                        value_vectors_only_global)
    # reshape attn probs
    attn_probs_without_global = attn_probs[:, :, :,
                                           max_num_global_attn_indices:]

    # compute attn output with global
    attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
        attn_probs_without_global, value_vectors,
        self._one_sided_attn_window_size)
    return attn_output_only_global + attn_output_without_global

  def _compute_global_attn_output_from_hidden(
      self,
      attn_output,
      hidden_states,
      max_num_global_attn_indices,
      layer_head_mask,
      is_local_index_global_attn_nonzero,
      is_index_global_attn_nonzero,
      is_local_index_no_global_attn_nonzero,
      is_index_masked,
      training,
  ):
    batch_size, seq_len = get_shape_list(hidden_states)[:2]

    # prepare global hidden states
    global_attn_hidden_states = tf.gather_nd(hidden_states,
                                             is_index_global_attn_nonzero)
    global_attn_hidden_states = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_attn_hidden_states,
        shape=(batch_size, max_num_global_attn_indices,
               self._num_heads * self._key_dim),
    )

    # global key, query, value
    global_query_vectors_only_global = self._global_query_dense(
        global_attn_hidden_states)
    global_key_vectors = self._global_key_dense(hidden_states)
    global_value_vectors = self._global_value_dense(hidden_states)

    # normalize
    global_query_vectors_only_global /= tf.math.sqrt(
        tf.cast(self._key_dim, dtype=global_query_vectors_only_global.dtype))
    global_query_vectors_only_global = self.reshape_and_transpose(
        global_query_vectors_only_global, batch_size)
    global_key_vectors = self.reshape_and_transpose(global_key_vectors,
                                                    batch_size)
    global_value_vectors = self.reshape_and_transpose(global_value_vectors,
                                                      batch_size)

    # compute attn scores
    global_attn_scores = tf.matmul(
        global_query_vectors_only_global, global_key_vectors, transpose_b=True)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(global_attn_scores),
          [batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
          message=f"global_attn_scores have the wrong size. Size should be "
          f"{(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, "
          f"but is {get_shape_list(global_attn_scores)}.",
      )

    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
    )
    global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
    mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
                 ) + tuple(get_shape_list(global_attn_scores_trans)[-2:])
    global_attn_mask = tf.ones(mask_shape) * -10000.0
    global_attn_mask = tf.cast(
        global_attn_mask, dtype=global_attn_scores_trans.dtype)

    # scatter mask
    global_attn_scores_trans = tf.tensor_scatter_nd_update(
        global_attn_scores_trans,
        is_local_index_no_global_attn_nonzero,
        global_attn_mask,
    )
    global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))

    # mask global attn scores
    attn_mask = tf.tile(is_index_masked[:, None, None, :],
                        (1, get_shape_list(global_attn_scores)[1], 1, 1))
    global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size * self._num_heads, max_num_global_attn_indices, seq_len),
    )

    # compute global attn probs
    global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)

    # apply layer head masking
    if layer_head_mask is not None:
      if tf.executing_eagerly():
        tf.debugging.assert_equal(
            get_shape_list(layer_head_mask), [self._num_heads],
            message=f"Head mask for a single layer should be of size "
            f"{(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
        )
      global_attn_probs_float = tf.reshape(
          layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
              global_attn_probs_float,
              (batch_size, self._num_heads, max_num_global_attn_indices,
               seq_len))
      global_attn_probs_float = tf.reshape(
          global_attn_probs_float,
          (batch_size * self._num_heads, max_num_global_attn_indices, seq_len))

    # dropout
    global_attn_probs = self._global_dropout_layer(
        global_attn_probs_float, training=training)

    # global attn output
    global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)

    if tf.executing_eagerly():
      tf.debugging.assert_equal(
          get_shape_list(global_attn_output),
          [batch_size * self._num_heads, max_num_global_attn_indices,
           self._key_dim],
          message=f"global_attn_output tensor has the wrong size. Size should be "
          f"{(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, "
          f"but is {get_shape_list(global_attn_output)}.",
      )

    global_attn_output = tf.reshape(
        global_attn_output,
        (batch_size, self._num_heads, max_num_global_attn_indices,
         self._key_dim),
    )

    # get only non zero global attn output
    nonzero_global_attn_output = tf.gather_nd(
        tf.transpose(global_attn_output, (0, 2, 1, 3)),
        is_local_index_global_attn_nonzero,
    )
    nonzero_global_attn_output = tf.reshape(
        nonzero_global_attn_output,
        (get_shape_list(is_local_index_global_attn_nonzero)[0], -1),
    )

    # overwrite values with global attention
    attn_output = tf.tensor_scatter_nd_update(attn_output,
                                              is_index_global_attn_nonzero,
                                              nonzero_global_attn_output)

    global_attn_probs = tf.reshape(
        global_attn_probs,
        (batch_size, self._num_heads, max_num_global_attn_indices, seq_len))

    attn_output = self._output_dense(attn_output)

    return attn_output, global_attn_probs

  def reshape_and_transpose(self, vector, batch_size):
    return tf.reshape(
        tf.transpose(
            tf.reshape(vector,
                       (batch_size, -1, self._num_heads, self._key_dim)),
            (0, 2, 1, 3),
        ),
        (batch_size * self._num_heads, -1, self._key_dim),
    )
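
As a quick usage sketch (hypothetical shapes, mirroring the unit tests below): the layer performs self-attention only, so a single `hidden_states` tensor is passed along with an additive attention mask and the two boolean index masks.

import tensorflow as tf
from official.projects.longformer import longformer_attention

# A minimal sketch (hypothetical shapes): local attention only.
layer = longformer_attention.LongformerAttention(
    num_heads=2, key_dim=4, value_dim=4,
    layer_id=0, attention_window=4, global_attention_size=0)
hidden_states = tf.random.normal((1, 8, 8))        # (batch, seq_len, hidden)
attention_mask = tf.zeros((1, 8, 1, 1))            # additive mask, 0 = attend
is_index_masked = tf.zeros((1, 8), dtype=tf.bool)
is_index_global_attn = tf.zeros((1, 8), dtype=tf.bool)
output = layer(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    is_index_masked=is_index_masked,
    is_index_global_attn=is_index_global_attn)     # -> (1, 8, 8)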
official/projects/longformer/longformer_attention_test.py
0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""
import numpy as np
import tensorflow as tf

from official.modeling.tf_utils import get_shape_list
from official.projects.longformer import longformer_attention


def _create_mock_attention_data(num_heads,
                                key_dim,
                                value_dim,
                                q_seq_length,
                                kv_seq_length,
                                batch_size,
                                include_mask=False):
  """Creates mock testing data.

  Args:
    num_heads: `int`, Number of attention heads.
    key_dim: `int`, Size of query head.
    value_dim: `int`, Size of key, value dim.
    q_seq_length: `int`, query sequence length of the input.
    kv_seq_length: `int`, key, value sequence length of the input.
    batch_size: `int`, the batch size.
    include_mask: optional `bool`, whether or not to include mask data.

  Returns:
    A dictionary with `str` as keys and `Tensor` as values.
  """
  query_shape = (batch_size, q_seq_length, key_dim)
  value_shape = (batch_size, kv_seq_length, value_dim)

  data = dict(
      query=tf.random.normal(shape=query_shape),
      value=tf.random.normal(shape=value_shape),
      key=tf.random.normal(shape=value_shape))

  total_seq_length = kv_seq_length

  if include_mask:
    mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
    mask_data = np.random.randint(2, size=mask_shape).astype('float32')
    mask_data = dict(attention_mask=mask_data)
    data.update(mask_data)

  return data


class LongformerAttentionTest(tf.test.TestCase):

  def setUp(self):
    super(LongformerAttentionTest, self).setUp()
    np.random.seed(0)
    tf.random.set_seed(0)

  def _get_hidden_states(self):
    return tf.convert_to_tensor(
        [[
            [
                4.98332758e-01, 2.69175139e00, -7.08081422e-03, 1.04915401e00,
                -1.83476661e00, 7.67220476e-01, 2.98580543e-01, 2.84803992e-02,
            ],
            [
                -7.58357372e-01, 4.20635998e-01, -4.04739919e-02,
                1.59924145e-01, 2.05135748e00, -1.15997978e00, 5.37166397e-01,
                2.62873606e-01,
            ],
            [
                -1.69438001e00, 4.17574660e-01, -1.49196962e00, -1.76483717e00,
                -1.94566312e-01, -1.71183858e00, 7.72903565e-01,
                -1.11557056e00,
            ],
            [
                5.44028163e-01, 2.05466114e-01, -3.63045868e-01,
                2.41865062e-01, 3.20348382e-01, -9.05611176e-01,
                -1.92690727e-01, -1.19917547e00,
            ],
        ]],
        dtype=tf.float32,
    )

  def test_diagonalize(self):
    hidden_states = self._get_hidden_states()
    hidden_states = tf.reshape(hidden_states,
                               (1, 8, 4))  # set seq length = 8, hidden dim = 4
    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    window_overlap_size = get_shape_list(chunked_hidden_states)[2]
    self.assertEqual(window_overlap_size, 4)

    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
        chunked_hidden_states)
    self.assertEqual(
        get_shape_list(padded_hidden_states)[-1],
        get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1)

    # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, 0, :4],
        chunked_hidden_states[0, 0, 0],
        rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, 0, 4:],
        tf.zeros((3,), dtype=tf.dtypes.float32),
        rtol=1e-3)

    # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, 3:],
        chunked_hidden_states[0, 0, -1],
        rtol=1e-3)
    tf.debugging.assert_near(
        padded_hidden_states[0, 0, -1, :3],
        tf.zeros((3,), dtype=tf.dtypes.float32),
        rtol=1e-3)

  def test_pad_and_transpose_last_two_dims(self):
    hidden_states = self._get_hidden_states()
    self.assertEqual(get_shape_list(hidden_states), [1, 4, 8])

    # pad along seq length dim
    paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
                           dtype=tf.dtypes.int32)
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)
    padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
        hidden_states, paddings)
    self.assertEqual(get_shape_list(padded_hidden_states), [1, 1, 8, 5])

    expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
    tf.debugging.assert_near(
        expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
    tf.debugging.assert_near(
        hidden_states[0, 0, -1, :],
        tf.reshape(padded_hidden_states, (1, -1))[0, 24:32],
        rtol=1e-6)

  def test_mask_invalid_locations(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))
    hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 1)
    hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states, 2)
    hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, :, :3], 2)
    hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
        hidden_states[:, :, 2:, :], 2)

    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)), 8)
    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)), 24)
    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)), 24)
    self.assertEqual(
        tf.math.reduce_sum(
            tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)), 12)

  def test_chunk(self):
    hidden_states = self._get_hidden_states()
    batch_size = 1
    seq_length = 8
    hidden_size = 4
    hidden_states = tf.reshape(hidden_states,
                               (batch_size, seq_length, hidden_size))

    chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
        hidden_states, window_overlap=2)

    # expected slices across chunk and seq length dim
    expected_slice_along_seq_length = tf.convert_to_tensor(
        [0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
    expected_slice_along_chunk = tf.convert_to_tensor(
        [0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)

    self.assertEqual(get_shape_list(chunked_hidden_states), [1, 3, 4, 4])
    tf.debugging.assert_near(
        chunked_hidden_states[0, :, 0, 0],
        expected_slice_along_seq_length,
        rtol=1e-3)
    tf.debugging.assert_near(
        chunked_hidden_states[0, 0, :, 0],
        expected_slice_along_chunk,
        rtol=1e-3)

  def test_layer_local_attn(self):
    hidden_states = self._get_hidden_states()
    batch_size, seq_length, _ = hidden_states.shape
    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,
        value_dim=4,
        layer_id=0,
        attention_window=4,
        global_attention_size=0,
    )
    attention_mask = tf.zeros((batch_size, seq_length),
                              dtype=tf.dtypes.float32)
    is_index_global_attn = tf.math.greater(attention_mask, 1)

    attention_mask = tf.where(
        tf.range(4)[None, :, None, None] > 1, -10000.0,
        attention_mask[:, :, None, None])
    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )

    self.assertEqual(output_hidden_states.shape, (1, 4, 8))

  def test_layer_global_attn(self):
    layer = longformer_attention.LongformerAttention(
        num_heads=2,
        key_dim=4,
        value_dim=4,
        layer_id=0,
        attention_window=4,
        global_attention_size=1,
    )
    hidden_states = tf.concat(
        [self._get_hidden_states(),
         self._get_hidden_states() - 0.5], axis=0)
    _, seq_length, _ = hidden_states.shape

    # create attn mask
    attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
    attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

    attention_mask_1 = tf.where(
        tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_1)
    attention_mask_1 = tf.where(
        tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
    attention_mask_2 = tf.where(
        tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_2)
    attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)

    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
    is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)

    output_hidden_states = layer(
        hidden_states=hidden_states,
        attention_mask=-tf.math.abs(attention_mask),
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
    )

    self.assertEqual(output_hidden_states.shape, (2, 4, 8))


if __name__ == '__main__':
  tf.test.main()
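
Note on the test fixture above: the hidden-state tensor is shape (1, 4, 8), and the layer returns the attention output directly (not a tuple, unlike the HuggingFace original), so the shape assertions check the full layer output; both tests use an additive mask convention where large negative values block attention and, in the global test, positive values mark global tokens before the sign is flipped with -tf.math.abs(...).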
official/projects/longformer/longformer_encoder.py
0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer encoder. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
from
absl
import
logging
import
tensorflow
as
tf
from
official.modeling.tf_utils
import
get_shape_list
from
official.nlp.modeling
import
layers
from
official.projects.longformer.longformer_encoder_block
import
LongformerEncoderBlock
_Initializer
=
Union
[
str
,
tf
.
keras
.
initializers
.
Initializer
]
_approx_gelu
=
lambda
x
:
tf
.
keras
.
activations
.
gelu
(
x
,
approximate
=
True
)
class
LongformerEncoder
(
tf
.
keras
.
layers
.
Layer
):
"""LongformerEncoder.
Args:
vocab_size: The size of the token vocabulary.
attention_window: list of ints representing the window size for each layer.
global_attention_size: the size of global attention used for each token.
pad_token_id: the token id for the pad token
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
"""
def
__init__
(
self
,
vocab_size
:
int
,
attention_window
:
Union
[
List
[
int
],
int
]
=
512
,
global_attention_size
:
int
=
0
,
pad_token_id
:
int
=
1
,
hidden_size
:
int
=
768
,
num_layers
:
int
=
12
,
num_attention_heads
:
int
=
12
,
max_sequence_length
:
int
=
512
,
type_vocab_size
:
int
=
16
,
inner_dim
:
int
=
3072
,
inner_activation
:
Callable
[...,
Any
]
=
_approx_gelu
,
output_dropout
:
float
=
0.1
,
attention_dropout
:
float
=
0.1
,
initializer
:
_Initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
0.02
),
output_range
:
Optional
[
int
]
=
None
,
embedding_width
:
Optional
[
int
]
=
None
,
embedding_layer
:
Optional
[
tf
.
keras
.
layers
.
Layer
]
=
None
,
norm_first
:
bool
=
False
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
# Longformer args
self
.
_attention_window
=
attention_window
self
.
_global_attention_size
=
global_attention_size
self
.
_pad_token_id
=
pad_token_id
activation
=
tf
.
keras
.
activations
.
get
(
inner_activation
)
initializer
=
tf
.
keras
.
initializers
.
get
(
initializer
)
if
embedding_width
is
None
:
embedding_width
=
hidden_size
if
embedding_layer
is
None
:
self
.
_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
name
=
'word_embeddings'
)
else
:
self
.
_embedding_layer
=
embedding_layer
self
.
_position_embedding_layer
=
layers
.
PositionEmbedding
(
initializer
=
initializer
,
max_length
=
max_sequence_length
,
name
=
'position_embedding'
)
self
.
_type_embedding_layer
=
layers
.
OnDeviceEmbedding
(
vocab_size
=
type_vocab_size
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
use_one_hot
=
True
,
name
=
'type_embeddings'
)
self
.
_embedding_norm_layer
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
'embeddings/layer_norm'
,
axis
=-
1
,
epsilon
=
1e-12
,
dtype
=
tf
.
float32
)
self
.
_embedding_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
output_dropout
,
name
=
'embedding_dropout'
)
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self
.
_embedding_projection
=
None
if
embedding_width
!=
hidden_size
:
self
.
_embedding_projection
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
'...x,xy->...y'
,
output_shape
=
hidden_size
,
bias_axes
=
'y'
,
kernel_initializer
=
initializer
,
name
=
'embedding_projection'
)
self
.
_transformer_layers
=
[]
self
.
_attention_mask_layer
=
layers
.
SelfAttentionMask
(
name
=
'self_attention_mask'
)
for
i
in
range
(
num_layers
):
layer
=
LongformerEncoderBlock
(
global_attention_size
=
global_attention_size
,
num_attention_heads
=
num_attention_heads
,
inner_dim
=
inner_dim
,
inner_activation
=
inner_activation
,
attention_window
=
attention_window
[
i
],
layer_id
=
i
,
output_dropout
=
output_dropout
,
attention_dropout
=
attention_dropout
,
norm_first
=
norm_first
,
output_range
=
output_range
if
i
==
num_layers
-
1
else
None
,
kernel_initializer
=
initializer
,
name
=
f
'transformer/layer_
{
i
}
'
)
self
.
_transformer_layers
.
append
(
layer
)
self
.
_pooler_layer
=
tf
.
keras
.
layers
.
Dense
(
units
=
hidden_size
,
activation
=
'tanh'
,
kernel_initializer
=
initializer
,
name
=
'pooler_transform'
)
self
.
_config
=
{
'vocab_size'
:
vocab_size
,
'hidden_size'
:
hidden_size
,
'num_layers'
:
num_layers
,
'num_attention_heads'
:
num_attention_heads
,
'max_sequence_length'
:
max_sequence_length
,
'type_vocab_size'
:
type_vocab_size
,
'inner_dim'
:
inner_dim
,
'inner_activation'
:
tf
.
keras
.
activations
.
serialize
(
activation
),
'output_dropout'
:
output_dropout
,
'attention_dropout'
:
attention_dropout
,
'initializer'
:
tf
.
keras
.
initializers
.
serialize
(
initializer
),
'output_range'
:
output_range
,
'embedding_width'
:
embedding_width
,
'embedding_layer'
:
embedding_layer
,
'norm_first'
:
norm_first
,
'attention_window'
:
attention_window
,
'global_attention_size'
:
global_attention_size
,
'pad_token_id'
:
pad_token_id
,
}
self
.
inputs
=
dict
(
input_word_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_mask
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
),
input_type_ids
=
tf
.
keras
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
))
def
call
(
self
,
inputs
):
word_embeddings
=
None
if
isinstance
(
inputs
,
dict
):
word_ids
=
inputs
.
get
(
'input_word_ids'
)
# input_ids
mask
=
inputs
.
get
(
'input_mask'
)
# attention_mask
type_ids
=
inputs
.
get
(
'input_type_ids'
)
# token_type_ids
word_embeddings
=
inputs
.
get
(
'input_word_embeddings'
,
None
)
# input_embeds
else
:
raise
ValueError
(
f
'Unexpected inputs type to
{
self
.
__class__
}
.'
)
(
padding_len
,
word_ids
,
mask
,
type_ids
,
word_embeddings
,
)
=
self
.
_pad_to_window_size
(
        word_ids=word_ids,
        mask=mask,
        type_ids=type_ids,
        word_embeddings=word_embeddings,
        pad_token_id=self._pad_token_id)

    if word_embeddings is None:
      word_embeddings = self._embedding_layer(word_ids)

    # absolute position embeddings.
    position_embeddings = self._position_embedding_layer(word_embeddings)
    type_embeddings = self._type_embedding_layer(type_ids)

    embeddings = word_embeddings + position_embeddings + type_embeddings
    embeddings = self._embedding_norm_layer(embeddings)
    embeddings = self._embedding_dropout(embeddings)

    if self._embedding_projection is not None:
      embeddings = self._embedding_projection(embeddings)

    batch_size, seq_len = get_shape_list(mask)
    # create masks with fixed len global_attention_size
    mask = tf.transpose(
        tf.concat(
            values=[
                tf.ones((self._global_attention_size, batch_size), tf.int32) *
                2,
                tf.transpose(mask)[self._global_attention_size:]
            ],
            axis=0))
    is_index_masked = tf.math.less(mask, 1)
    is_index_global_attn = tf.transpose(
        tf.concat(
            values=[
                tf.ones((self._global_attention_size, batch_size), tf.bool),
                tf.zeros((seq_len - self._global_attention_size, batch_size),
                         tf.bool)
            ],
            axis=0))

    # Longformer
    attention_mask = mask
    extended_attention_mask = tf.reshape(
        attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1))
    attention_mask = tf.cast(
        tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0

    encoder_outputs = []
    x = embeddings
    # TFLongformerEncoder
    for layer in self._transformer_layers:
      x = layer([x, attention_mask, is_index_masked, is_index_global_attn])
      encoder_outputs.append(x)

    last_encoder_output = encoder_outputs[-1]
    if padding_len > 0:
      last_encoder_output = last_encoder_output[:, :-padding_len]

    first_token_tensor = last_encoder_output[:, 0, :]
    pooled_output = self._pooler_layer(first_token_tensor)

    return dict(
        sequence_output=last_encoder_output,
        pooled_output=pooled_output,
        encoder_outputs=encoder_outputs)

  def get_embedding_table(self):
    return self._embedding_layer.embeddings

  def get_embedding_layer(self):
    return self._embedding_layer

  def get_config(self):
    return dict(self._config)

  @property
  def transformer_layers(self):
    """List of Transformer layers in the encoder."""
    return self._transformer_layers

  @property
  def pooler_layer(self):
    """The pooler dense layer after the transformer layers."""
    return self._pooler_layer

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if 'embedding_layer' in config and config['embedding_layer'] is not None:
      warn_string = (
          'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you continue to '
          'train this model, the embedding layer will no longer be shared. '
          'To work around this, load the model outside of the Keras API.')
      print('WARNING: ' + warn_string)
      logging.warn(warn_string)
    return cls(**config)

  def _pad_to_window_size(
      self,
      word_ids,
      mask,
      type_ids,
      word_embeddings,
      pad_token_id,
  ):
    # padding
    attention_window = max(self._attention_window)

    assert (attention_window % 2 == 0), (
        '`attention_window` should be an even value. '
        f'Given {attention_window}')

    input_shape = get_shape_list(word_ids) if word_ids is not None else (
        get_shape_list(word_embeddings))
    batch_size, seq_len = input_shape[:2]
    if seq_len is not None:
      padding_len = (attention_window -
                     seq_len % attention_window) % attention_window
    else:
      padding_len = 0

    paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

    if word_ids is not None:
      word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)

    if word_embeddings is not None:

      def pad_embeddings():
        # Note: the original referenced `self.pad_token_id`, which is not an
        # attribute of this class; the `pad_token_id` argument is used instead.
        word_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
        word_embeddings_padding = self._embedding_layer(word_ids_padding)
        return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)

      word_embeddings = tf.cond(
          tf.math.greater(padding_len, 0), pad_embeddings,
          lambda: word_embeddings)

    mask = tf.pad(
        mask, paddings,
        constant_values=False)  # no attention on the padding tokens
    token_type_ids = tf.pad(
        type_ids, paddings, constant_values=0)  # pad with token_type_id = 0

    return (
        padding_len,
        word_ids,
        mask,
        token_type_ids,
        word_embeddings,
    )
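Reviewer note: `_pad_to_window_size` rounds the sequence length up to a multiple of the largest attention window before the sliding-window attention runs. A standalone sketch of that arithmetic for reference (the helper name and the values are illustrative, not part of this commit):

def window_padding_len(seq_len, attention_window):
  # Mirrors the padding computation in `_pad_to_window_size`.
  assert attention_window % 2 == 0, '`attention_window` should be even.'
  return (attention_window - seq_len % attention_window) % attention_window

assert window_padding_len(128, 32) == 0   # already a multiple of the window
assert window_padding_len(130, 32) == 30  # padded up to 160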
official/projects/longformer/longformer_encoder_block.py
0 → 100644
View file @ 8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention layer. Modified From huggingface/transformers."""
import
tensorflow
as
tf
from
official.projects.longformer.longformer_attention
import
LongformerAttention
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Text"
)
class
LongformerEncoderBlock
(
tf
.
keras
.
layers
.
Layer
):
"""LongformerEncoderBlock.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
"""
def
__init__
(
self
,
global_attention_size
,
num_attention_heads
,
inner_dim
,
inner_activation
,
# Longformer
attention_window
,
layer_id
=
0
,
output_range
=
None
,
kernel_initializer
=
"glorot_uniform"
,
bias_initializer
=
"zeros"
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
output_dropout
=
0.0
,
attention_dropout
=
0.0
,
inner_dropout
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
global_attention_size
=
global_attention_size
self
.
_num_heads
=
num_attention_heads
self
.
_inner_dim
=
inner_dim
self
.
_inner_activation
=
inner_activation
# Longformer
self
.
_attention_window
=
attention_window
self
.
_layer_id
=
layer_id
self
.
_attention_dropout
=
attention_dropout
self
.
_attention_dropout_rate
=
attention_dropout
self
.
_output_dropout
=
output_dropout
self
.
_output_dropout_rate
=
output_dropout
self
.
_output_range
=
output_range
self
.
_kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
kernel_initializer
)
self
.
_bias_initializer
=
tf
.
keras
.
initializers
.
get
(
bias_initializer
)
self
.
_kernel_regularizer
=
tf
.
keras
.
regularizers
.
get
(
kernel_regularizer
)
self
.
_bias_regularizer
=
tf
.
keras
.
regularizers
.
get
(
bias_regularizer
)
self
.
_activity_regularizer
=
tf
.
keras
.
regularizers
.
get
(
activity_regularizer
)
self
.
_kernel_constraint
=
tf
.
keras
.
constraints
.
get
(
kernel_constraint
)
self
.
_bias_constraint
=
tf
.
keras
.
constraints
.
get
(
bias_constraint
)
self
.
_use_bias
=
use_bias
self
.
_norm_first
=
norm_first
self
.
_norm_epsilon
=
norm_epsilon
self
.
_inner_dropout
=
inner_dropout
if
attention_initializer
:
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
attention_initializer
)
else
:
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_axes
=
attention_axes
def
build
(
self
,
input_shape
):
if
isinstance
(
input_shape
,
tf
.
TensorShape
):
input_tensor_shape
=
input_shape
elif
isinstance
(
input_shape
,
(
list
,
tuple
)):
input_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
else
:
raise
ValueError
(
f
"The type of input shape argument is not supported, got: "
f
"
{
type
(
input_shape
)
}
"
)
einsum_equation
=
"abc,cd->abd"
if
len
(
input_tensor_shape
.
as_list
())
>
3
:
einsum_equation
=
"...bc,cd->...bd"
hidden_size
=
input_tensor_shape
[
-
1
]
if
hidden_size
%
self
.
_num_heads
!=
0
:
raise
ValueError
(
f
"The input size (
{
hidden_size
}
) is not a multiple of the number of attention "
f
"heads (
{
self
.
_num_heads
}
)"
)
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
# TFLongformerSelfAttention + TFLongformerSelfOutput.dense
self
.
_attention_layer
=
LongformerAttention
(
# Longformer
layer_id
=
self
.
_layer_id
,
global_attention_size
=
self
.
global_attention_size
,
attention_window
=
self
.
_attention_window
,
num_heads
=
self
.
_num_heads
,
key_dim
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout
,
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_attention_initializer
,
attention_axes
=
self
.
_attention_axes
,
name
=
"self_attention"
,
**
common_kwargs
)
# TFLongformerSelfOutput.dropout
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
# TFLongformerSelfOutput.Layernorm
self
.
_attention_layer_norm
=
(
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"self_attention_layer_norm"
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
# TFLongformerIntermediate
# TFLongformerIntermediate.dense
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
(
None
,
self
.
_inner_dim
),
bias_axes
=
"d"
,
kernel_initializer
=
self
.
_kernel_initializer
,
name
=
"intermediate"
,
**
common_kwargs
)
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
if
policy
.
name
==
"mixed_bfloat16"
:
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy
=
tf
.
float32
# TFLongformerIntermediate.intermediate_act_fn
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_inner_activation
,
dtype
=
policy
)
self
.
_inner_dropout_layer
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_inner_dropout
)
# TFLongformerOutput
# TFLongformerOutput.dense
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
(
None
,
hidden_size
),
bias_axes
=
"d"
,
name
=
"output"
,
kernel_initializer
=
self
.
_kernel_initializer
,
**
common_kwargs
)
# TFLongformerOutput.dropout
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
# TFLongformerOutput.layernorm
self
.
_output_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"output_layer_norm"
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
)
super
().
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
"num_attention_heads"
:
self
.
_num_heads
,
"inner_dim"
:
self
.
_inner_dim
,
"inner_activation"
:
self
.
_inner_activation
,
"output_dropout"
:
self
.
_output_dropout_rate
,
"attention_dropout"
:
self
.
_attention_dropout_rate
,
"output_range"
:
self
.
_output_range
,
"kernel_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_kernel_initializer
),
"bias_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_bias_initializer
),
"kernel_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_kernel_regularizer
),
"bias_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_bias_regularizer
),
"activity_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_activity_regularizer
),
"kernel_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_kernel_constraint
),
"bias_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
"use_bias"
:
self
.
_use_bias
,
"norm_first"
:
self
.
_norm_first
,
"norm_epsilon"
:
self
.
_norm_epsilon
,
"inner_dropout"
:
self
.
_inner_dropout
,
"attention_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_attention_initializer
),
"attention_axes"
:
self
.
_attention_axes
,
}
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
):
"""Transformer self-attention encoder block call.
Args:
inputs: a single tensor or a list of tensors. `input tensor` as the single
sequence of embeddings. [`input tensor`, `attention mask`] to have the
additional attention mask. [`query tensor`, `key value tensor`,
`attention mask`] to have separate input streams for the query, and
key/value to the multi-head attention.
Returns:
An output tensor with the same dimensions as input/query tensor.
"""
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
4
:
(
input_tensor
,
attention_mask
,
is_index_masked
,
is_index_global_attn
,
)
=
inputs
key_value
=
None
elif
len
(
inputs
)
==
5
:
assert
False
# No key_value
else
:
raise
ValueError
(
f
"Unexpected inputs to
{
self
.
__class__
}
with length at
{
len
(
inputs
)
}
"
)
else
:
input_tensor
=
inputs
attention_mask
=
None
is_index_masked
=
None
is_index_global_attn
=
None
key_value
=
None
if
self
.
_output_range
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
if
is_index_masked
is
not
None
:
is_index_masked
=
is_index_masked
[:,
0
:
self
.
_output_range
]
if
is_index_global_attn
is
not
None
:
is_index_global_attn
=
is_index_global_attn
[:,
0
:
self
.
_output_range
]
else
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
target_tensor
=
input_tensor
if
key_value
is
None
:
key_value
=
input_tensor
attention_output
=
self
.
_attention_layer
(
hidden_states
=
target_tensor
,
attention_mask
=
attention_mask
,
is_index_masked
=
is_index_masked
,
is_index_global_attn
=
is_index_global_attn
,
)
# TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
attention_output
=
self
.
_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
attention_output
=
source_tensor
+
attention_output
else
:
attention_output
=
self
.
_attention_layer_norm
(
target_tensor
+
attention_output
)
if
self
.
_norm_first
:
source_attention_output
=
attention_output
attention_output
=
self
.
_output_layer_norm
(
attention_output
)
# TFLongformerIntermediate
inner_output
=
self
.
_intermediate_dense
(
attention_output
)
inner_output
=
self
.
_intermediate_activation_layer
(
inner_output
)
inner_output
=
self
.
_inner_dropout_layer
(
inner_output
)
# TFLongformerOutput
layer_output
=
self
.
_output_dense
(
inner_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
if
self
.
_norm_first
:
return
source_attention_output
+
layer_output
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output
=
tf
.
cast
(
layer_output
,
tf
.
float32
)
return
self
.
_output_layer_norm
(
layer_output
+
attention_output
)
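Reviewer note: a quick smoke test of the block, mirroring how `LongformerEncoder.call` above invokes each transformer layer. This is a sketch under stated assumptions (shapes, hyperparameters, and the all-visible mask are illustrative; it is not part of this commit):

import tensorflow as tf
from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock

block = LongformerEncoderBlock(
    global_attention_size=1,
    num_attention_heads=4,
    inner_dim=1024,
    inner_activation='relu',
    attention_window=32)  # must be even; seq_len should be a multiple of it
batch_size, seq_len, hidden_size = 2, 128, 256
x = tf.random.normal((batch_size, seq_len, hidden_size))
# Additive mask of shape (batch, seq, 1, 1); zeros mean every token is visible.
attention_mask = tf.zeros((batch_size, seq_len, 1, 1))
is_index_masked = tf.zeros((batch_size, seq_len), tf.bool)
# The first `global_attention_size` tokens attend globally.
is_index_global_attn = tf.concat(
    [tf.ones((batch_size, 1), tf.bool),
     tf.zeros((batch_size, seq_len - 1), tf.bool)], axis=-1)
y = block([x, attention_mask, is_index_masked, is_index_global_attn])
print(y.shape)  # expected: (2, 128, 256)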
official/projects/longformer/longformer_encoder_test.py
0 → 100644
View file @ 8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_encoder."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
official.projects.longformer.longformer_encoder
import
LongformerEncoder
class
LongformerEncoderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
LongformerEncoderTest
,
self
).
setUp
()
np
.
random
.
seed
(
0
)
tf
.
random
.
set_seed
(
0
)
@
combinations
.
generate
(
combinations
.
combine
(
attention_window
=
[
32
,
128
],
global_attention_size
=
[
0
,
1
,
2
]))
def
test_encoder
(
self
,
attention_window
,
global_attention_size
):
sequence_length
=
128
batch_size
=
2
vocab_size
=
1024
hidden_size
=
256
network
=
LongformerEncoder
(
global_attention_size
=
global_attention_size
,
vocab_size
=
vocab_size
,
attention_window
=
[
attention_window
],
hidden_size
=
hidden_size
,
num_layers
=
1
,
num_attention_heads
=
4
,
max_sequence_length
=
512
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
'input_word_ids'
:
word_id_data
,
'input_mask'
:
mask_data
,
'input_type_ids'
:
type_id_data
,
}
outputs
=
network
(
inputs
)
self
.
assertEqual
(
outputs
[
'sequence_output'
].
shape
,
(
batch_size
,
sequence_length
,
hidden_size
))
@
combinations
.
generate
(
combinations
.
combine
(
norm_first
=
[
True
,
False
],
global_attention_size
=
[
0
,
1
,
2
]))
def
test_norm_first
(
self
,
norm_first
,
global_attention_size
):
sequence_length
=
128
batch_size
=
2
vocab_size
=
1024
hidden_size
=
256
network
=
LongformerEncoder
(
global_attention_size
=
global_attention_size
,
vocab_size
=
vocab_size
,
attention_window
=
[
32
],
hidden_size
=
hidden_size
,
num_layers
=
1
,
num_attention_heads
=
4
,
max_sequence_length
=
512
,
norm_first
=
norm_first
)
word_id_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
'input_word_ids'
:
word_id_data
,
'input_mask'
:
mask_data
,
'input_type_ids'
:
type_id_data
,
}
outputs
=
network
(
inputs
)
self
.
assertEqual
(
outputs
[
'sequence_output'
].
shape
,
(
batch_size
,
sequence_length
,
hidden_size
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/projects/longformer/longformer_experiments.py
0 → 100644
View file @ 8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer experiments."""
# pylint: disable=g-doc-return-or-yield,line-too-long
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.data
import
sentence_prediction_dataloader
from
official.nlp.tasks
import
masked_lm
from
official.nlp.tasks
import
sentence_prediction
from
official.projects.longformer.longformer
import
LongformerEncoderConfig
AdamWeightDecay
=
optimization
.
AdamWeightDecayConfig
PolynomialLr
=
optimization
.
PolynomialLrConfig
PolynomialWarmupConfig
=
optimization
.
PolynomialWarmupConfig
@
dataclasses
.
dataclass
class
LongformerOptimizationConfig
(
optimization
.
OptimizationConfig
):
"""Longformer optimization configuration."""
optimizer
:
optimization
.
OptimizerConfig
=
optimization
.
OptimizerConfig
(
type
=
'adamw'
,
adamw
=
AdamWeightDecay
(
weight_decay_rate
=
0.01
,
exclude_from_weight_decay
=
[
'LayerNorm'
,
'layer_norm'
,
'bias'
],
epsilon
=
1e-6
))
learning_rate
:
optimization
.
LrConfig
=
optimization
.
LrConfig
(
type
=
'polynomial'
,
polynomial
=
PolynomialLr
(
initial_learning_rate
=
1e-4
,
decay_steps
=
1000000
,
end_learning_rate
=
0.0
))
warmup
:
optimization
.
WarmupConfig
=
optimization
.
WarmupConfig
(
type
=
'polynomial'
,
polynomial
=
PolynomialWarmupConfig
(
warmup_steps
=
10000
))
@
exp_factory
.
register_config_factory
(
'longformer/pretraining'
)
def
longformer_pretraining
()
->
cfg
.
ExperimentConfig
:
"""Longformer pretraining experiment."""
config
=
cfg
.
ExperimentConfig
(
runtime
=
cfg
.
RuntimeConfig
(
enable_xla
=
True
),
task
=
masked_lm
.
MaskedLMConfig
(
model
=
bert
.
PretrainerConfig
(
encoder
=
encoders
.
EncoderConfig
(
type
=
'any'
,
any
=
LongformerEncoderConfig
()),
cls_heads
=
[
bert
.
ClsHeadConfig
(
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
]),
train_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
),
validation_data
=
pretrain_dataloader
.
BertPretrainDataConfig
(
use_v2_feature_names
=
True
,
is_training
=
False
)),
trainer
=
cfg
.
TrainerConfig
(
optimizer_config
=
LongformerOptimizationConfig
(),
train_steps
=
1000000
),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'longformer/glue'
)
def
longformer_glue
()
->
cfg
.
ExperimentConfig
:
"""Longformer glue fine-tuning."""
config
=
cfg
.
ExperimentConfig
(
task
=
sentence_prediction
.
SentencePredictionConfig
(
model
=
sentence_prediction
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
type
=
'any'
,
any
=
LongformerEncoderConfig
())),
train_data
=
sentence_prediction_dataloader
.
SentencePredictionDataConfig
(),
validation_data
=
sentence_prediction_dataloader
.
SentencePredictionDataConfig
(
is_training
=
False
,
drop_remainder
=
False
)),
trainer
=
cfg
.
TrainerConfig
(
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.01
,
'exclude_from_weight_decay'
:
[
'LayerNorm'
,
'layer_norm'
,
'bias'
],
}
},
'learning_rate'
:
{
'type'
:
'polynomial'
,
'polynomial'
:
{
'initial_learning_rate'
:
3e-5
,
'end_learning_rate'
:
0.0
,
}
},
'warmup'
:
{
'type'
:
'polynomial'
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
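Reviewer note: once this module is imported, the registered experiments can be looked up by name through the factory. A minimal sketch (the input path and override values are illustrative):

from official.core import exp_factory
# Importing the module registers 'longformer/pretraining' and 'longformer/glue'.
from official.projects.longformer import longformer_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('longformer/glue')
config.task.train_data.input_path = '/path/to/train.tf_record'  # illustrative
config.task.train_data.global_batch_size = 32
config.validate()  # enforces the `restrictions` declared above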
official/vision/beta/train.py → official/projects/longformer/train.py
View file @ 8b641b13
...
@@ -12,22 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-"""TensorFlow Model Garden Vision training driver."""
+"""A customized training library for the specific task."""
 from absl import app
 from absl import flags
 import gin
-# pylint: disable=unused-import
-from official.common import registry_imports
-# pylint: enable=unused-import
 from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
+from official.projects.longformer import longformer_experiments  # pylint: disable=unused-import
 
 FLAGS = flags.FLAGS
...
@@ -51,7 +48,9 @@ def main(_):
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
       num_gpus=params.runtime.num_gpus,
-      tpu_address=params.runtime.tpu)
+      tpu_address=params.runtime.tpu,
+      **params.runtime.model_parallelism())
   with distribution_strategy.scope():
     task = task_factory.get_task(params.task, logging_dir=model_dir)
...
@@ -64,7 +63,7 @@ def main(_):
   train_utils.save_gin_config(FLAGS.mode, model_dir)
 
 if __name__ == '__main__':
   tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)
official/projects/longformer/utils/convert_pretrained_pytorch_checkpoint_to_tf.py
0 → 100644
View file @ 8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts pre-trained pytorch checkpoint into a tf encoder checkpoint."""
import
os
from
absl
import
app
import
numpy
as
np
import
tensorflow
as
tf
import
transformers
from
official.modeling
import
tf_utils
from
official.projects.longformer.longformer
import
LongformerEncoderConfig
from
official.projects.longformer.longformer_encoder
import
LongformerEncoder
def
_get_pytorch_longformer_model
():
pretrained_lm
=
"allenai/longformer-base-4096"
model
=
transformers
.
AutoModel
.
from_pretrained
(
pretrained_lm
)
return
{
n
:
p
.
data
.
numpy
()
for
n
,
p
in
model
.
named_parameters
()}
def
_create_longformer_model
():
"""Creates a Longformer model."""
encoder_cfg
=
LongformerEncoderConfig
encoder_cfg
.
vocab_size
=
50265
encoder_cfg
.
max_position_embeddings
=
4098
encoder_cfg
.
attention_window
=
[
2
]
*
encoder_cfg
.
num_layers
encoder_cfg
.
global_attention_size
=
1
encoder
=
LongformerEncoder
(
attention_window
=
encoder_cfg
.
attention_window
,
global_attention_size
=
encoder_cfg
.
global_attention_size
,
vocab_size
=
encoder_cfg
.
vocab_size
,
hidden_size
=
encoder_cfg
.
hidden_size
,
num_layers
=
encoder_cfg
.
num_layers
,
num_attention_heads
=
encoder_cfg
.
num_attention_heads
,
inner_dim
=
encoder_cfg
.
intermediate_size
,
inner_activation
=
tf_utils
.
get_activation
(
encoder_cfg
.
hidden_activation
),
output_dropout
=
encoder_cfg
.
dropout_rate
,
attention_dropout
=
encoder_cfg
.
attention_dropout_rate
,
max_sequence_length
=
encoder_cfg
.
max_position_embeddings
,
type_vocab_size
=
encoder_cfg
.
type_vocab_size
,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
encoder_cfg
.
initializer_range
),
output_range
=
encoder_cfg
.
output_range
,
embedding_width
=
encoder_cfg
.
embedding_size
,
norm_first
=
encoder_cfg
.
norm_first
)
return
encoder
# pylint: disable=protected-access
def
convert
(
encoder
,
allenai_model
):
"""Convert AllenAI Longformer to the one in the codebase."""
num_layers
=
encoder
.
_config
[
"num_layers"
]
num_attention_heads
=
encoder
.
_config
[
"num_attention_heads"
]
hidden_size
=
encoder
.
_config
[
"hidden_size"
]
head_size
=
hidden_size
//
num_attention_heads
assert
head_size
*
num_attention_heads
==
hidden_size
encoder
.
_embedding_layer
.
set_weights
(
[
allenai_model
[
"embeddings.word_embeddings.weight"
]])
encoder
.
_embedding_norm_layer
.
set_weights
([
allenai_model
[
"embeddings.LayerNorm.weight"
],
allenai_model
[
"embeddings.LayerNorm.bias"
]
])
encoder
.
_type_embedding_layer
.
set_weights
([
np
.
repeat
(
allenai_model
[
"embeddings.token_type_embeddings.weight"
],
2
,
axis
=
0
)
])
encoder
.
_position_embedding_layer
.
set_weights
(
[
allenai_model
[
"embeddings.position_embeddings.weight"
]])
encoder
.
_pooler_layer
.
set_weights
([
allenai_model
[
"pooler.dense.weight"
],
allenai_model
[
"pooler.dense.bias"
]
])
for
layer_num
in
range
(
num_layers
):
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_global_key_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.key_global.weight"
].
T
.
reshape
(
(
hidden_size
,
num_attention_heads
,
head_size
)),
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.key_global.bias"
]
.
reshape
((
num_attention_heads
,
head_size
))
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_global_query_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.query_global.weight"
]
.
T
.
reshape
((
hidden_size
,
num_attention_heads
,
head_size
)),
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.query_global.bias"
]
.
reshape
((
num_attention_heads
,
head_size
))
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_global_value_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.value_global.weight"
]
.
T
.
reshape
((
hidden_size
,
num_attention_heads
,
head_size
)),
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.value_global.bias"
]
.
reshape
((
num_attention_heads
,
head_size
))
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_key_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.key.weight"
].
T
.
reshape
(
(
hidden_size
,
num_attention_heads
,
head_size
)),
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.key_global.bias"
]
.
reshape
((
num_attention_heads
,
head_size
))
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_query_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.query.weight"
].
T
.
reshape
((
hidden_size
,
num_attention_heads
,
head_size
)),
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.query.bias"
].
reshape
(
(
num_attention_heads
,
head_size
))
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_value_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.value.weight"
].
T
.
reshape
((
hidden_size
,
num_attention_heads
,
head_size
)),
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.self.value.bias"
].
reshape
(
(
num_attention_heads
,
head_size
))
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer
.
_output_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.output.dense.weight"
].
T
,
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.output.dense.bias"
]
])
encoder
.
_transformer_layers
[
layer_num
].
_attention_layer_norm
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.output.LayerNorm.weight"
],
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.attention.output.LayerNorm.bias"
]
])
encoder
.
_transformer_layers
[
layer_num
].
_intermediate_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.intermediate.dense.weight"
].
T
,
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.intermediate.dense.bias"
]
])
encoder
.
_transformer_layers
[
layer_num
].
_output_dense
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.output.dense.weight"
].
T
,
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.output.dense.bias"
]
])
encoder
.
_transformer_layers
[
layer_num
].
_output_layer_norm
.
set_weights
([
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.output.LayerNorm.weight"
],
allenai_model
[
f
"encoder.layer.
{
layer_num
}
.output.LayerNorm.bias"
]
])
def
convert_checkpoint
(
output_path
):
"""Converts and save the checkpoint."""
output_dir
,
_
=
os
.
path
.
split
(
output_path
)
tf
.
io
.
gfile
.
makedirs
(
output_dir
)
encoder
=
_create_longformer_model
()
allenai_model
=
_get_pytorch_longformer_model
()
sequence_length
=
128
batch_size
=
2
word_id_data
=
np
.
random
.
randint
(
10
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
mask_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
type_id_data
=
np
.
random
.
randint
(
2
,
size
=
(
batch_size
,
sequence_length
),
dtype
=
np
.
int32
)
inputs
=
{
"input_word_ids"
:
word_id_data
,
"input_mask"
:
mask_data
,
"input_type_ids"
:
type_id_data
,
}
encoder
(
inputs
)
convert
(
encoder
,
allenai_model
)
tf
.
train
.
Checkpoint
(
encoder
=
encoder
).
write
(
output_path
)
def
main
(
_
):
convert_checkpoint
(
"longformer-4096/longformer"
)
if
__name__
==
"__main__"
:
app
.
run
(
main
)
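Reviewer note: the written checkpoint can be restored into a freshly constructed encoder. A minimal sketch under the same configuration assumptions as `_create_longformer_model` above (the path matches the one hard-coded in `main`):

import tensorflow as tf

# Restores the converted weights into a freshly built encoder. Variables are
# created lazily, so run a forward pass (as `convert_checkpoint` does) before
# asserting that everything matched.
encoder = _create_longformer_model()
status = tf.train.Checkpoint(encoder=encoder).read("longformer-4096/longformer")
status.expect_partial()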
official/projects/longformer/utils/longformer_tokenizer_to_tfrecord.py
0 → 100644
View file @ 8b641b13
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Longformer training examples to Tfrecord."""
import
collections
import
os
import
datasets
import
tensorflow
as
tf
import
transformers
pretrained_lm
=
"allenai/longformer-base-4096"
task_name
=
"mnli"
save_path
=
"./"
raw_datasets
=
datasets
.
load_dataset
(
"glue"
,
task_name
,
cache_dir
=
None
)
label_list
=
raw_datasets
[
"train"
].
features
[
"label"
].
names
num_labels
=
len
(
label_list
)
tokenizer
=
transformers
.
AutoTokenizer
.
from_pretrained
(
pretrained_lm
,
use_fast
=
True
,
)
task_to_keys
=
{
"cola"
:
(
"sentence"
,
None
),
"mnli"
:
(
"premise"
,
"hypothesis"
),
"mrpc"
:
(
"sentence1"
,
"sentence2"
),
"qnli"
:
(
"question"
,
"sentence"
),
"qqp"
:
(
"question1"
,
"question2"
),
"rte"
:
(
"sentence1"
,
"sentence2"
),
"sst2"
:
(
"sentence"
,
None
),
"stsb"
:
(
"sentence1"
,
"sentence2"
),
"wnli"
:
(
"sentence1"
,
"sentence2"
),
}
sentence1_key
,
sentence2_key
=
task_to_keys
[
task_name
]
padding
=
"max_length"
# make sure this is the same with model input size.
max_seq_length
=
512
def
preprocess_function
(
examples
):
# Tokenize the texts
args
=
((
examples
[
sentence1_key
],)
if
sentence2_key
is
None
else
(
examples
[
sentence1_key
],
examples
[
sentence2_key
]))
result
=
tokenizer
(
*
args
,
padding
=
padding
,
max_length
=
max_seq_length
,
truncation
=
True
)
return
result
raw_datasets
=
raw_datasets
.
map
(
preprocess_function
,
batched
=
True
,
desc
=
"Running tokenizer on dataset"
,
)
train_dataset
=
raw_datasets
[
"train"
]
eval_dataset
=
raw_datasets
[
"validation_matched"
if
task_name
==
"mnli"
else
"validation"
]
print
(
"train_dataset"
,
train_dataset
[
0
])
print
(
"eval_dataset"
,
eval_dataset
[
0
])
def
file_based_convert_examples_to_features
(
examples
,
output_file
):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
writer
=
tf
.
io
.
TFRecordWriter
(
output_file
)
for
ex_index
,
example
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
print
(
f
"Writing example
{
ex_index
}
of
{
len
(
examples
)
}
"
)
def
create_int_feature
(
values
):
f
=
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
list
(
values
)))
return
f
features
=
collections
.
OrderedDict
()
features
[
"input_ids"
]
=
create_int_feature
(
example
[
"input_ids"
])
features
[
"input_mask"
]
=
create_int_feature
(
example
[
"attention_mask"
])
features
[
"segment_ids"
]
=
create_int_feature
([
0
]
*
len
(
example
[
"attention_mask"
]))
features
[
"label_ids"
]
=
create_int_feature
([
example
[
"label"
]])
features
[
"is_real_example"
]
=
create_int_feature
([
1
])
features
[
"example_id"
]
=
create_int_feature
([
example
[
"idx"
]])
tf_example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
features
))
writer
.
write
(
tf_example
.
SerializeToString
())
writer
.
close
()
file_based_convert_examples_to_features
(
train_dataset
,
os
.
path
.
join
(
save_path
,
f
"
{
pretrained_lm
.
replace
(
'/'
,
'_'
)
}
_train.tf_record"
))
file_based_convert_examples_to_features
(
eval_dataset
,
os
.
path
.
join
(
save_path
,
f
"
{
pretrained_lm
.
replace
(
'/'
,
'_'
)
}
_eval.tf_record"
))
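Reviewer note: a sketch of reading one record back to verify the written files. The schema mirrors `file_based_convert_examples_to_features` above, and `max_seq_length` must equal the value used at write time (512 here); the filename follows the `pretrained_lm.replace('/', '_')` pattern:

import tensorflow as tf

name_to_features = {
    "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    "label_ids": tf.io.FixedLenFeature([1], tf.int64),
}
dataset = tf.data.TFRecordDataset(
    "allenai_longformer-base-4096_train.tf_record")
example = tf.io.parse_single_example(next(iter(dataset)), name_to_features)
print({name: t.shape for name, t in example.items()})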
official/projects/movinet/README.md
View file @ 8b641b13
...
@@ -176,8 +176,7 @@ devices. See the [TF Lite Example](#tf-lite-example) to export and run your own
 models. We also provide
 [quantized TF Lite binaries via TF Hub](https://tfhub.dev/s?deployment-format=lite&q=movinet).
 For reference, MoViNet-A0-Stream runs with a similar latency to
-[MobileNetV3-Large]
-(https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
+[MobileNetV3-Large](https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
 with +5% accuracy on Kinetics 600.
 
 | Model Name | Input Shape | Pixel 4 Latency\* | x86 Latency\* | TF Lite Binary |
...
official/projects/movinet/modeling/movinet.py
View file @ 8b641b13
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Contains definitions of Mobile Video Networks.
 
 Reference: https://arxiv.org/pdf/2103.11511.pdf
...
official/projects/movinet/modeling/movinet_layers.py
View file @ 8b641b13
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Contains common building blocks for MoViNets.
 
 Reference: https://arxiv.org/pdf/2103.11511.pdf
...
official/projects/movinet/modeling/movinet_layers_test.py
View file @ 8b641b13
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Tests for movinet_layers.py."""
 
 from absl.testing import parameterized
...
official/projects/movinet/modeling/movinet_model_test.py
View file @ 8b641b13
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Tests for movinet_model.py."""
 
 from absl.testing import parameterized
...
official/projects/movinet/modeling/movinet_test.py
View file @ 8b641b13
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 """Tests for movinet.py."""
 
 from absl.testing import parameterized
...
official/projects/movinet/tools/export_saved_model.py
View file @ 8b641b13
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
 r"""Exports models to tf.saved_model.
 
 Export example:
...