Unverified Commit 8b641b13 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 7cffacfe 357fa547
task:
hub_module_url: ''
model:
num_classes: 3
encoder:
type: any
any:
max_position_embeddings: 512
attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
global_attention_size: 1
metric_type: 'accuracy'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: true
seq_length: 128
validation_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: false
seq_length: 128
trainer:
checkpoint_interval: 1000
continuous_eval_timeout: 7200
optimizer_config:
learning_rate:
polynomial:
decay_steps: 61359
end_learning_rate: 0.0
initial_learning_rate: 3.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 6136
type: polynomial
steps_per_loop: 100
summary_interval: 100
# Training data size 392,702 examples, 5 epochs.
train_steps: 61359
validation_interval: 2000
validation_steps: 307
task:
hub_module_url: ''
model:
num_classes: 3
encoder:
type: any
any:
max_position_embeddings: 4098
attention_window: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
global_attention_size: 1
vocab_size: 50265
metric_type: 'accuracy'
train_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: true
seq_length: 512
validation_data:
drop_remainder: true
global_batch_size: 32
input_path: TODO
is_training: false
seq_length: 512
trainer:
checkpoint_interval: 1000
continuous_eval_timeout: 7200
optimizer_config:
learning_rate:
polynomial:
decay_steps: 61359
end_learning_rate: 0.0
initial_learning_rate: 3.0e-05
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 6136
type: polynomial
steps_per_loop: 1000
summary_interval: 1000
# Training data size 392,702 examples, 5 epochs.
train_steps: 61359
validation_interval: 2000
validation_steps: 307
task:
init_checkpoint: ""
model:
cls_heads:
[
{
activation: tanh,
cls_token_idx: 0,
dropout_rate: 0.1,
inner_dim: 768,
name: next_sentence,
num_classes: 2,
},
]
encoder:
type: any
any:
attention_dropout_rate: 0.1
dropout_rate: 0.1
embedding_size: 768
hidden_activation: gelu
hidden_size: 768
initializer_range: 0.02
intermediate_size: 3072
max_position_embeddings: 512
num_attention_heads: 12
num_layers: 12
type_vocab_size: 2
vocab_size: 30522
attention_window: [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
global_attention_size: 1
train_data:
drop_remainder: true
global_batch_size: 256
input_path: gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00000-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00001-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00002-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00003-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00004-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00005-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00006-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00007-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00008-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00009-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00010-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00011-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00012-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00013-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00014-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00015-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00016-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00017-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00018-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00019-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00020-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00021-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00022-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00023-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00024-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00025-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00026-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00027-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00028-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00029-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00030-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00031-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00032-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00033-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00034-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00035-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00036-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00037-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00038-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00039-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00040-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00041-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00042-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00043-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00044-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00045-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00046-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00047-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00048-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00049-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00050-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00051-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00052-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00053-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00054-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00055-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00056-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00057-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00058-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00059-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00060-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00061-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00062-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00063-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00064-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00065-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00066-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00067-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00068-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00069-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00070-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00071-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00072-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00073-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00074-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00075-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00076-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00077-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00078-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00079-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00080-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00081-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00082-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00083-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00084-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00085-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00086-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00087-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00088-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00089-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00090-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00091-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00092-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00093-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00094-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00095-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00096-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00097-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00098-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00099-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00100-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00101-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00102-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00103-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00104-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00105-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00106-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00107-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00108-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00109-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00110-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00111-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00112-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00113-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00114-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00115-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00116-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00117-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00118-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00119-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00120-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00121-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00122-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00123-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00124-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00125-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00126-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00127-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00128-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00129-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00130-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00131-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00132-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00133-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00134-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00135-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00136-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00137-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00138-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00139-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00140-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00141-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00142-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00143-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00144-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00145-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00146-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00147-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00148-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00149-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00150-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00151-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00152-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00153-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00154-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00155-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00156-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00157-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00158-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00159-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00160-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00161-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00162-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00163-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00164-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00165-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00166-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00167-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00168-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00169-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00170-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00171-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00172-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00173-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00174-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00175-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00176-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00177-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00178-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00179-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00180-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00181-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00182-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00183-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00184-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00185-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00186-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00187-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00188-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00189-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00190-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00191-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00192-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00193-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00194-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00195-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00196-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00197-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00198-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00199-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00200-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00201-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00202-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00203-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00204-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00205-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00206-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00207-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00208-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00209-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00210-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00211-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00212-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00213-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00214-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00215-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00216-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00217-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00218-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00219-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00220-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00221-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00222-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00223-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00224-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00225-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00226-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00227-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00228-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00229-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00230-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00231-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00232-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00233-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00234-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00235-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00236-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00237-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00238-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00239-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00240-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00241-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00242-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00243-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00244-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00245-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00246-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00247-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00248-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00249-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00250-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00251-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00252-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00253-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00254-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00255-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00256-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00257-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00258-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00259-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00260-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00261-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00262-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00263-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00264-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00265-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00266-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00267-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00268-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00269-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00270-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00271-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00272-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00273-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00274-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00275-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00276-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00277-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00278-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00279-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00280-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00281-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00282-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00283-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00284-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00285-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00286-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00287-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00288-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00289-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00290-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00291-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00292-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00293-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00294-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00295-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00296-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00297-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00298-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00299-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00300-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00301-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00302-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00303-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00304-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00305-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00306-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00307-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00308-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00309-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00310-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00311-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00312-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00313-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00314-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00315-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00316-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00317-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00318-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00319-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00320-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00321-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00322-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00323-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00324-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00325-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00326-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00327-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00328-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00329-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00330-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00331-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00332-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00333-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00334-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00335-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00336-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00337-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00338-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00339-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00340-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00341-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00342-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00343-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00344-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00345-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00346-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00347-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00348-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00349-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00350-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00351-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00352-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00353-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00354-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00355-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00356-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00357-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00358-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00359-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00360-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00361-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00362-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00363-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00364-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00365-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00366-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00367-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00368-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00369-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00370-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00371-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00372-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00373-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00374-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00375-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00376-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00377-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00378-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00379-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00380-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00381-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00382-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00383-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00384-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00385-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00386-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00387-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00388-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00389-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00390-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00391-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00392-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00393-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00394-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00395-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00396-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00397-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00398-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00399-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00400-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00401-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00402-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00403-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00404-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00405-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00406-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00407-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00408-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00409-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00410-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00411-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00412-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00413-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00414-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00415-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00416-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00417-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00418-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00419-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00420-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00421-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00422-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00423-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00424-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00425-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00426-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00427-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00428-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00429-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00430-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00431-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00432-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00433-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00434-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00435-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00436-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00437-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00438-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00439-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00440-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00441-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00442-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00443-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00444-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00445-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00446-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00447-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00448-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00449-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00450-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00451-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00452-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00453-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00454-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00455-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00456-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00457-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00458-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00459-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00460-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00461-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00462-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00463-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00464-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00465-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00466-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00467-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00468-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00469-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00470-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00471-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00472-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00473-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00474-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00475-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00476-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00477-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00478-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00479-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00480-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00481-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00482-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00483-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00484-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00485-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00486-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00487-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00488-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00489-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00490-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00491-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00492-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00493-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00494-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00495-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00496-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00497-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00498-of-00500,gs://tf_model_garden/nlp/data/research_data/bert_pretrain/wikipedia.tfrecord-00499-of-00500
is_training: true
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
validation_data:
drop_remainder: true
global_batch_size: 256
input_path: TODO
is_training: false
max_predictions_per_seq: 76
seq_length: 512
use_next_sentence_label: true
use_position_id: false
trainer:
checkpoint_interval: 20000
max_to_keep: 5
optimizer_config:
learning_rate:
polynomial:
cycle: false
decay_steps: 1000000
end_learning_rate: 0.0
initial_learning_rate: 0.0001
power: 1.0
type: polynomial
optimizer:
type: adamw
warmup:
polynomial:
power: 1
warmup_steps: 10000
type: polynomial
steps_per_loop: 50
summary_interval: 50
train_steps: 1000000
validation_interval: 1000
validation_steps: 64
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer model configurations and instantiation methods."""
import dataclasses
from typing import List
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.projects.longformer.longformer_encoder import LongformerEncoder
@dataclasses.dataclass
class LongformerEncoderConfig(encoders.BertEncoderConfig):
"""Extra paramerters for Longformer configs.
Attributes:
attention_window: list of ints representing the window size for each layer.
global_attention_size: the size of global attention used for each token.
pad_token_id: the token id for the pad token
"""
attention_window: List[int] = dataclasses.field(default_factory=list)
global_attention_size: int = 0
pad_token_id: int = 1
@base_config.bind(LongformerEncoderConfig)
def get_encoder(encoder_cfg: LongformerEncoderConfig):
"""Gets a 'LongformerEncoder' object.
Args:
encoder_cfg: A 'LongformerEncoderConfig'.
Returns:
A encoder object.
"""
encoder = LongformerEncoder(
attention_window=encoder_cfg.attention_window,
global_attention_size=encoder_cfg.global_attention_size,
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.intermediate_size,
inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
output_dropout=encoder_cfg.dropout_rate,
attention_dropout=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_size,
norm_first=encoder_cfg.norm_first)
return encoder
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention block. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
import math
import string
import numpy as np
import tensorflow as tf
from official.modeling.tf_utils import get_shape_list
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
`(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
`bs` and `<non-attention dims>` are treated as `<batch dims>`.
The attention operations can be generalized:
(1) Query-key dot product:
`(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
<key attention dims>, num_heads, channels) -> (<batch dims>,
num_heads, <query attention dims>, <key attention dims>)`
(2) Combination:
`(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
<query attention dims>, num_heads, channels)`
Args:
rank: Rank of query, key, value tensors.
attn_axes: List/tuple of axes, `[-1, rank)`, that attention will be applied
to.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = f"{source_notation},{target_notation}->{product_notation}"
attn_scores_rank = len(product_notation)
combine_equation = f"{product_notation},{source_notation}->{target_notation}"
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = f"{input_str},{kernel_str}->{output_str}"
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerAttention(tf.keras.layers.MultiHeadAttention):
"""LongformerAttention.
Args:
attention_window: int representing the window size for attention.
layer_id: int of the id of the layer.
global_attention_size: the size of global attention used for each token.
"""
def __init__(self, attention_window, layer_id, global_attention_size,
**kwargs):
super().__init__(**kwargs)
self._layer_id = layer_id
self._attention_window = attention_window
assert (self._attention_window % 2 == 0), (
f"`attention_window` for layer {self._layer_id} has to be an even "
f"value. Given {self.attention_window}")
assert (self._attention_window > 0), (
f"`attention_window` for layer {self._layer_id} has to be positive. "
f"Given {self.attention_window}")
self._one_sided_attn_window_size = self._attention_window // 2
self.global_attention_size = global_attention_size
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: Query tensor or TensorShape.
value: Value tensor or TensorShape.
key: Key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
self._query_shape = tf.TensorShape(query.shape)
else:
self._query_shape = tf.TensorShape(query)
if hasattr(value, "shape"):
self._value_shape = tf.TensorShape(value.shape)
else:
self._value_shape = tf.TensorShape(value)
if key is None:
self._key_shape = self._value_shape
elif hasattr(key, "shape"):
self._key_shape = tf.TensorShape(key.shape)
else:
self._key_shape = tf.TensorShape(key)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# Any setup work performed only once should happen in an `init_scope`
# to avoid creating symbolic Tensors that will later pollute any eager
# operations.
# with tf_utils.maybe_init_scope(self):
# TODO(crickwu): check whether tf_utils.maybe_init_scope(self) (keras)
# is needed.
free_dims = self._query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
self._global_query_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="global_query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
self._global_key_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="global_key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
self._global_value_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="global_value",
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# it support mult-head einsum computations.
self._build_attention(output_rank)
self._global_dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
# self._output_dense = self._make_output_dense(
# free_dims, common_kwargs, "attention_output")
self._output_dense = tf.keras.layers.Dense(
units=self._num_heads * self._key_dim, name="dense", **common_kwargs)
def call(self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
training=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
multi-head Q, K, V inputs. Users can override this function for customized
attention implementation.
Args:
hidden_states: inputs for generating query, key and value tensors.
attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
attention to certain positions.
is_index_masked: boolean indicating whether the index is masked.
is_index_global_attn: boolean indicating whether the index is global
attention.
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (doing nothing).
Returns:
attention_output: Multi-headed outputs of attention computation.
"""
if not self._built_from_signature:
self._build_from_signature(
query=hidden_states, value=hidden_states, key=hidden_states)
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(hidden_states)
# `key` = [B, S, N, H]
key = self._key_dense(hidden_states)
# `value` = [B, S, N, H]
value = self._value_dense(hidden_states)
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query, key, self._one_sided_attn_window_size)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
tf.ones(get_shape_list(attention_mask)),
attention_mask,
self._one_sided_attn_window_size,
)
# pad local attention probs
attn_scores += diagonal_mask
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(attn_scores),
[
batch_size, seq_len, self._num_heads,
self._one_sided_attn_window_size * 2 + 1
],
message=f"attn_probs should be of size "
f"({batch_size}, {seq_len}, {num_heads}, "
f"{self._one_sided_attn_window_size * 2 + 1}),"
f" but is of size {get_shape_list(attn_scores)}",
)
# compute global attn indices required through out forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn,
self.global_attention_size)
# this function is only relevant for global attention
if self.global_attention_size > 0:
attn_scores = self._concat_with_global_key_attn_probs(
attn_scores=attn_scores,
query_vectors=query,
key_vectors=key,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
else:
pass
attn_probs = tf.nn.softmax(attn_scores, axis=-1)
# softmax sometimes inserts NaN if all positions are masked,
# replace them with 0
# Make sure to create a mask with the proper shape:
# if is_global_attn==True => [batch_size, seq_len, self.num_heads,
# self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
# if is_global_attn==False => [batch_size, seq_len, self.num_heads,
# self.one_sided_attn_window_size * 2 + 1]
if self.global_attention_size > 0:
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
max_num_global_attn_indices + 1),
)
else:
masked_index = tf.tile(
is_index_masked[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
)
attn_probs = tf.where(
masked_index,
tf.zeros(get_shape_list(masked_index), dtype=attn_probs.dtype),
attn_probs,
)
layer_head_mask = None
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size "
f"{(self._num_heads)}, but is "
f"{get_shape_list(layer_head_mask)}",
)
attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
# apply dropout
attn_probs = self._dropout_layer(attn_probs, training=training)
value_vectors = tf.reshape(
value, (batch_size, seq_len, self._num_heads, self._key_dim))
# if global attention, compute sum of global and local attn
if self.global_attention_size > 0:
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
attn_output = self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self._one_sided_attn_window_size)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(attn_output),
[batch_size, seq_len, self._num_heads, head_dim],
message="Unexpected size",
)
attn_output = tf.reshape(
attn_output,
(batch_size, seq_len, self._num_heads * self._key_dim)) # FIXME
# compute value for global attention and overwrite to attention output
# TODO(crickwu): remove the redundant computation
if self.global_attention_size > 0:
attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden( # pylint: disable=unused-variable
attn_output=attn_output,
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
layer_head_mask=layer_head_mask,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
training=training,
)
else:
global_attn_probs = tf.zeros(
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len))
# make sure that local attention probabilities are set to 0 for indices of
# global attn
if self.global_attention_size > 0:
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 +
max_num_global_attn_indices + 1),
)
else:
masked_global_attn_index = tf.tile(
is_index_global_attn[:, :, None, None],
(1, 1, self._num_heads, self._one_sided_attn_window_size * 2 + 1),
)
attn_probs = tf.where(
masked_global_attn_index,
tf.zeros(
get_shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
attn_probs,
)
# we can return extra information here
# (attn_output, attn_probs, global_attn_probs)
return attn_output
def get_config(self):
config = {
"layer_id": self._layer_id,
"attention_window": self._one_sided_attn_window_size,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
"""Matrix multiplication of query and key tensors.
This multiplication uses a sliding window attention pattern.
This implementation splits the input into overlapping chunks of size
2w (e.g. 512 for pretrained Longformer) with an overlap of size
window_overlap.
Args:
query: query tensor.
key: key tensor.
window_overlap: int.
Returns:
diagonal_attention_scores: tensor.
"""
batch_size, seq_len, num_heads, head_dim = get_shape_list(query)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message=f"Sequence length should be multiple of {window_overlap * 2}. "
f"Given {seq_len}",
)
tf.debugging.assert_equal(
get_shape_list(query),
get_shape_list(key),
message=f"Shape of query and key should be equal, but got query: "
f"{get_shape_list(query)} and key: {get_shape_list(key)}",
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one,
# then chunk seq_len into chunks of size window_overlap * 2
query = tf.reshape(
tf.transpose(query, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
key = tf.reshape(
tf.transpose(key, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim))
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query,
chunked_key) # multiply
# convert diagonals into columns
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are
# combined. The last dimension
# has (window_overlap * 2 + 1) columns. The first (window_overlap) columns
# are the window_overlap lower triangles (attention from a word to
# window_overlap previous words). The following column is attention score
# from each word to itself, then
# followed by window_overlap columns for the upper triangle.
# copy parts from diagonal_chunked_attention_scores into the combined matrix
# of attentions - copying the main diagonal and the upper triangle
# TODO(crickwu): This code is most likely not very efficient and should be
# improved.
diagonal_attn_scores_up_triang = tf.concat(
[
diagonal_chunked_attention_scores[:, :, :window_overlap, :
window_overlap + 1],
diagonal_chunked_attention_scores[:, -1:,
window_overlap:, :window_overlap +
1],
],
axis=1,
)
# - copying the lower triangle
diagonal_attn_scores_low_triang = tf.concat(
[
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
diagonal_chunked_attention_scores[:, :, -(window_overlap + 1):-1,
window_overlap + 1:],
],
axis=1,
)
diagonal_attn_scores_first_chunk = tf.concat(
[
tf.roll(
diagonal_chunked_attention_scores,
shift=[1, window_overlap],
axis=[2, 3],
)[:, :, :window_overlap, :window_overlap],
tf.zeros(
(batch_size * num_heads, 1, window_overlap, window_overlap),
dtype=diagonal_chunked_attention_scores.dtype,
),
],
axis=1,
)
first_chunk_mask = (
tf.tile(
tf.range(chunks_count + 1)[None, :, None, None],
(batch_size * num_heads, 1, window_overlap, window_overlap),
) < 1)
diagonal_attn_scores_low_triang = tf.where(
first_chunk_mask,
diagonal_attn_scores_first_chunk,
diagonal_attn_scores_low_triang,
)
# merging upper and lower triangle
diagonal_attention_scores = tf.concat(
[diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang],
axis=-1)
# separate batch_size and num_heads dimensions again
diagonal_attention_scores = tf.transpose(
tf.reshape(
diagonal_attention_scores,
(batch_size, num_heads, seq_len, 2 * window_overlap + 1),
),
(0, 2, 1, 3),
)
diagonal_attention_scores = self._mask_invalid_locations(
diagonal_attention_scores, window_overlap)
return diagonal_attention_scores
@staticmethod
def _mask_invalid_locations(input_tensor, window_overlap):
# create correct upper triangle bool mask
mask_2d_upper = tf.reverse(
tf.linalg.band_part(
tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
axis=[0],
)
# pad to full matrix
padding = tf.convert_to_tensor(
[[0, get_shape_list(input_tensor)[1] - window_overlap],
[0, get_shape_list(input_tensor)[3] - window_overlap - 1]])
# create lower mask
mask_2d = tf.pad(mask_2d_upper, padding)
# combine with upper mask
mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
# broadcast to full matrix
mask_4d = tf.tile(mask_2d[None, :, None, :],
(get_shape_list(input_tensor)[0], 1, 1, 1))
# inf tensor used for masking
inf_tensor = -float("inf") * tf.ones_like(input_tensor)
# mask
input_tensor = tf.where(
tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
return input_tensor
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value,
window_overlap):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value."""
batch_size, seq_len, num_heads, head_dim = get_shape_list(value)
if tf.executing_eagerly():
tf.debugging.assert_equal(
seq_len % (window_overlap * 2),
0,
message="Seq_len has to be multiple of 2 * window_overlap",
)
tf.debugging.assert_equal(
get_shape_list(attn_probs)[:3],
get_shape_list(value)[:3],
message="value and attn_probs must have same dims (except head_dim)",
)
tf.debugging.assert_equal(
get_shape_list(attn_probs)[3],
2 * window_overlap + 1,
message="attn_probs last dim has to be 2 * window_overlap + 1",
)
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len
# into chunks of size 2 window overlap
chunked_attn_probs = tf.reshape(
tf.transpose(attn_probs, (0, 2, 1, 3)),
(
batch_size * num_heads,
seq_len // window_overlap,
window_overlap,
2 * window_overlap + 1,
),
)
# group batch_size and num_heads dimensions into one
value = tf.reshape(
tf.transpose(value, (0, 2, 1, 3)),
(batch_size * num_heads, seq_len, head_dim),
)
# pad seq_len with w at the beginning of the sequence and another window
# overlap at the end
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap],
[0, 0]])
padded_value = tf.pad(value, paddings, constant_values=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of
# size window overlap
frame_size = 3 * window_overlap * head_dim
frame_hop_size = (get_shape_list(padded_value)[1] * head_dim -
frame_size) // chunks_count
chunked_value = tf.signal.frame(
tf.reshape(padded_value, (batch_size * num_heads, -1)),
frame_size,
frame_hop_size,
)
chunked_value = tf.reshape(
chunked_value,
(batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
head_dim),
)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(chunked_value),
[
batch_size * num_heads, chunks_count + 1, 3 * window_overlap,
head_dim
],
message="Chunked value has the wrong shape",
)
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
context = tf.transpose(
tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
(0, 2, 1, 3),
)
return context
@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
"""Pads rows and then flips rows and columns."""
hidden_states_padded = tf.pad(
hidden_states_padded, paddings
) # padding value is not important because it will be overwritten
batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
hidden_states_padded)
hidden_states_padded = tf.reshape(
hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
return hidden_states_padded
@staticmethod
def _pad_and_diagonalize(chunked_hidden_states):
"""Shifts every row 1 step right, converting columns into diagonals.
Example::
chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
2.0514, -1.1600, 0.5372, 0.2629 ]
window_overlap = num_rows = 4
(pad & diagonalize) =>
[ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
Args:
chunked_hidden_states: tensor.
Returns:
padded_hidden_stategs: tensor.
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = get_shape_list(
chunked_hidden_states)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0],
[0, window_overlap + 1]])
chunked_hidden_states = tf.pad(chunked_hidden_states, paddings)
chunked_hidden_states = tf.reshape(chunked_hidden_states,
(total_num_heads, num_chunks, -1))
chunked_hidden_states = chunked_hidden_states[:, :, :-window_overlap]
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
(total_num_heads, num_chunks, window_overlap,
window_overlap + hidden_dim),
)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
@staticmethod
def _chunk(hidden_states, window_overlap):
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w."""
batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
# define frame size and frame stride (similar to convolution)
frame_hop_size = window_overlap * hidden_dim
frame_size = 2 * frame_hop_size
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length * hidden_dim))
# chunk with overlap
chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
frame_hop_size)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(chunked_hidden_states),
[batch_size, num_output_chunks, frame_size],
message=f"Make sure chunking is correctly applied. `Chunked hidden "
f"states should have output dimension"
f" {[batch_size, frame_size, num_output_chunks]}, but got "
f"{get_shape_list(chunked_hidden_states)}.",
)
chunked_hidden_states = tf.reshape(
chunked_hidden_states,
(batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
)
return chunked_hidden_states
@staticmethod
def _get_global_attn_indices(is_index_global_attn, global_attention_size):
"""Computes global attn indices required throughout forward pass."""
# All global attention size are fixed through global_attention_size
batch_size, _ = get_shape_list(is_index_global_attn)
max_num_global_attn_indices = global_attention_size
row_indices = tf.range(batch_size)
row_indices = tf.repeat(
tf.expand_dims(row_indices, axis=0),
repeats=[global_attention_size],
axis=0)
row_indices = tf.reshape(row_indices,
(batch_size * global_attention_size, 1))
col_indices = tf.range(global_attention_size)
col_indices = tf.repeat(
tf.expand_dims(col_indices, axis=1), repeats=[batch_size], axis=0)
is_index_global_attn_nonzero = tf.concat((row_indices, col_indices), axis=1)
# this is actually same as `is_index_global_attn_nonzero`,
# since we assume all global attention are the same size
is_local_index_global_attn_nonzero = tf.concat((row_indices, col_indices),
axis=1)
# empty tensor
is_local_index_no_global_attn_nonzero = tf.reshape(
tf.expand_dims(tf.range(0), axis=1), (0, 2))
return (
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
)
def _concat_with_global_key_attn_probs(
self,
attn_scores,
key_vectors,
query_vectors,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
):
batch_size = get_shape_list(key_vectors)[0]
# select global key vectors
global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
# create only global key vectors
key_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_key_vectors,
shape=(
batch_size,
max_num_global_attn_indices,
self._num_heads,
self._key_dim,
),
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors,
key_vectors_only_global)
# (batch_size, max_num_global_attn_indices, seq_len, num_heads)
attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key,
(0, 3, 1, 2))
mask_shape = (
get_shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
get_shape_list(attn_probs_from_global_key_trans)[-2:])
mask = tf.ones(mask_shape) * -10000.0
mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
# scatter mask
attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
attn_probs_from_global_key_trans,
is_local_index_no_global_attn_nonzero,
mask,
)
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans,
(0, 2, 3, 1))
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)
return attn_scores
def _compute_attn_output_with_global_indices(
self,
value_vectors,
attn_probs,
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
):
batch_size = get_shape_list(attn_probs)[0]
# cut local attn probs to global only
attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
# select global value vectors
global_value_vectors = tf.gather_nd(value_vectors,
is_index_global_attn_nonzero)
# create only global value vectors
value_vectors_only_global = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_value_vectors,
shape=(
batch_size,
max_num_global_attn_indices,
self._num_heads,
self._key_dim,
),
)
# compute attn output only global
attn_output_only_global = tf.einsum("blhs,bshd->blhd",
attn_probs_only_global,
value_vectors_only_global)
# reshape attn probs
attn_probs_without_global = attn_probs[:, :, :,
max_num_global_attn_indices:]
# compute attn output with global
attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
attn_probs_without_global, value_vectors,
self._one_sided_attn_window_size)
return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
self,
attn_output,
hidden_states,
max_num_global_attn_indices,
layer_head_mask,
is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
is_index_masked,
training,
):
batch_size, seq_len = get_shape_list(hidden_states)[:2]
# prepare global hidden states
global_attn_hidden_states = tf.gather_nd(hidden_states,
is_index_global_attn_nonzero)
global_attn_hidden_states = tf.scatter_nd(
is_local_index_global_attn_nonzero,
global_attn_hidden_states,
shape=(batch_size, max_num_global_attn_indices,
self._num_heads * self._key_dim),
)
# global key, query, value
global_query_vectors_only_global = self._global_query_dense(
global_attn_hidden_states)
global_key_vectors = self._global_key_dense(hidden_states)
global_value_vectors = self._global_value_dense(hidden_states)
# normalize
global_query_vectors_only_global /= tf.math.sqrt(
tf.cast(self._key_dim, dtype=global_query_vectors_only_global.dtype))
global_query_vectors_only_global = self.reshape_and_transpose(
global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors,
batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors,
batch_size)
# compute attn scores
global_attn_scores = tf.matmul(
global_query_vectors_only_global, global_key_vectors, transpose_b=True)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(global_attn_scores),
[batch_size * self._num_heads, max_num_global_attn_indices, seq_len],
message=f"global_attn_scores have the wrong size. Size should be"
f"{(batch_size * self._num_heads, max_num_global_attn_indices, seq_len)}, "
f"but is {get_shape_list(global_attn_scores)}.",
)
global_attn_scores = tf.reshape(
global_attn_scores,
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len),
)
global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
mask_shape = (get_shape_list(is_local_index_no_global_attn_nonzero)[0],
) + tuple(get_shape_list(global_attn_scores_trans)[-2:])
global_attn_mask = tf.ones(mask_shape) * -10000.0
global_attn_mask = tf.cast(
global_attn_mask, dtype=global_attn_scores_trans.dtype)
# scatter mask
global_attn_scores_trans = tf.tensor_scatter_nd_update(
global_attn_scores_trans,
is_local_index_no_global_attn_nonzero,
global_attn_mask,
)
global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
# mask global attn scores
attn_mask = tf.tile(is_index_masked[:, None, None, :],
(1, get_shape_list(global_attn_scores)[1], 1, 1))
global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
global_attn_scores = tf.reshape(
global_attn_scores,
(batch_size * self._num_heads, max_num_global_attn_indices, seq_len),
)
# compute global attn probs
global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1)
# apply layer head masking
if layer_head_mask is not None:
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(layer_head_mask),
[self._num_heads],
message=f"Head mask for a single layer should be of size "
f"{(self._num_heads)}, but is {get_shape_list(layer_head_mask)}",
)
global_attn_probs_float = tf.reshape(
layer_head_mask,
(1, -1, 1, 1)) * tf.reshape(global_attn_probs_float,
(batch_size, self._num_heads,
max_num_global_attn_indices, seq_len))
global_attn_probs_float = tf.reshape(
global_attn_probs_float,
(batch_size * self._num_heads, max_num_global_attn_indices, seq_len))
# dropout
global_attn_probs = self._global_dropout_layer(
global_attn_probs_float, training=training)
# global attn output
global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
if tf.executing_eagerly():
tf.debugging.assert_equal(
get_shape_list(global_attn_output),
[
batch_size * self._num_heads, max_num_global_attn_indices,
self._key_dim
],
message=f"global_attn_output tensor has the wrong size. Size should be "
f"{(batch_size * self._num_heads, max_num_global_attn_indices, self._key_dim)}, "
f"but is {get_shape_list(global_attn_output)}.",
)
global_attn_output = tf.reshape(
global_attn_output,
(batch_size, self._num_heads, max_num_global_attn_indices,
self._key_dim),
)
# get only non zero global attn output
nonzero_global_attn_output = tf.gather_nd(
tf.transpose(global_attn_output, (0, 2, 1, 3)),
is_local_index_global_attn_nonzero,
)
nonzero_global_attn_output = tf.reshape(
nonzero_global_attn_output,
(get_shape_list(is_local_index_global_attn_nonzero)[0], -1),
)
# overwrite values with global attention
attn_output = tf.tensor_scatter_nd_update(attn_output,
is_index_global_attn_nonzero,
nonzero_global_attn_output)
global_attn_probs = tf.reshape(
global_attn_probs,
(batch_size, self._num_heads, max_num_global_attn_indices, seq_len))
attn_output = self._output_dense(attn_output)
return attn_output, global_attn_probs
def reshape_and_transpose(self, vector, batch_size):
return tf.reshape(
tf.transpose(
tf.reshape(vector,
(batch_size, -1, self._num_heads, self._key_dim)),
(0, 2, 1, 3),
),
(batch_size * self._num_heads, -1, self._key_dim),
)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_attention."""
import numpy as np
import tensorflow as tf
from official.modeling.tf_utils import get_shape_list
from official.projects.longformer import longformer_attention
def _create_mock_attention_data(num_heads,
key_dim,
value_dim,
q_seq_length,
kv_seq_length,
batch_size,
include_mask=False):
"""Creates mock testing data.
Args:
num_heads: `int`, Number of attention heads.
key_dim: `int`, Size of query head.
value_dim: `int`, Size of key, value dim.
q_seq_length: `int`, query sequence length of the input.
kv_seq_length: `int`, key, value sequence length of the input.
batch_size: `int`, the batch size.
include_mask: optional `bool`, whether or not to include mask data.
Returns:
A dictionary with `str` as keys and `Tensor` as values.
"""
query_shape = (batch_size, q_seq_length, key_dim)
value_shape = (batch_size, kv_seq_length, value_dim)
data = dict(
query=tf.random.normal(shape=query_shape),
value=tf.random.normal(shape=value_shape),
key=tf.random.normal(shape=value_shape))
total_seq_length = kv_seq_length
if include_mask:
mask_shape = (batch_size, num_heads, q_seq_length, total_seq_length)
mask_data = np.random.randint(2, size=mask_shape).astype('float32')
mask_data = dict(attention_mask=mask_data)
data.update(mask_data)
return data
class LongformerAttentionTest(tf.test.TestCase):
def setUp(self):
super(LongformerAttentionTest, self).setUp()
np.random.seed(0)
tf.random.set_seed(0)
def _get_hidden_states(self):
return tf.convert_to_tensor(
[[
[
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]],
dtype=tf.float32,
)
def test_diagonalize(self):
hidden_states = self._get_hidden_states()
hidden_states = tf.reshape(hidden_states,
(1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
window_overlap_size = get_shape_list(chunked_hidden_states)[2]
self.assertEqual(window_overlap_size, 4)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_diagonalize(
chunked_hidden_states)
self.assertEqual(
get_shape_list(padded_hidden_states)[-1],
get_shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf.debugging.assert_near(
padded_hidden_states[0, 0, 0, :4],
chunked_hidden_states[0, 0, 0],
rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, 0, 4:],
tf.zeros((3,), dtype=tf.dtypes.float32),
rtol=1e-3)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, 3:],
chunked_hidden_states[0, 0, -1],
rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, :3],
tf.zeros((3,), dtype=tf.dtypes.float32),
rtol=1e-3)
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertTrue(get_shape_list(hidden_states), [1, 8, 4])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]],
dtype=tf.dtypes.int32)
hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
padded_hidden_states = longformer_attention.LongformerAttention._pad_and_transpose_last_two_dims(
hidden_states, paddings)
self.assertEqual(get_shape_list(padded_hidden_states), [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(
expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, 0, -1, :],
tf.reshape(padded_hidden_states, (1, -1))[0, 24:32],
rtol=1e-6)
def test_mask_invalid_locations(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length, hidden_size))
hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
hid_states_1 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states, 1)
hid_states_2 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states, 2)
hid_states_3 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, :, :3], 2)
hid_states_4 = longformer_attention.LongformerAttention._mask_invalid_locations(
hidden_states[:, :, 2:, :], 2)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)), 8)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)), 24)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)), 24)
self.assertEqual(
tf.math.reduce_sum(
tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)), 12)
def test_chunk(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states,
(batch_size, seq_length, hidden_size))
chunked_hidden_states = longformer_attention.LongformerAttention._chunk(
hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = tf.convert_to_tensor(
[0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor(
[0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
self.assertEqual(get_shape_list(chunked_hidden_states), [1, 3, 4, 4])
tf.debugging.assert_near(
chunked_hidden_states[0, :, 0, 0],
expected_slice_along_seq_length,
rtol=1e-3)
tf.debugging.assert_near(
chunked_hidden_states[0, 0, :, 0],
expected_slice_along_chunk,
rtol=1e-3)
def test_layer_local_attn(self):
hidden_states = self._get_hidden_states()
batch_size, seq_length, _ = hidden_states.shape
layer = longformer_attention.LongformerAttention(
num_heads=2,
key_dim=4,
value_dim=4,
layer_id=0,
attention_window=4,
global_attention_size=0,
)
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
is_index_global_attn = tf.math.greater(attention_mask, 1)
attention_mask = tf.where(
tf.range(4)[None, :, None, None] > 1, -10000.0,
attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)[0]
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
def test_layer_global_attn(self):
layer = longformer_attention.LongformerAttention(
num_heads=2,
key_dim=4,
value_dim=4,
layer_id=0,
attention_window=4,
global_attention_size=1,
)
hidden_states = self._get_hidden_states()
hidden_states = tf.concat(
[self._get_hidden_states(),
self._get_hidden_states() - 0.5], axis=0)
_, seq_length, _ = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(
tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(
tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(
tf.range(4)[None, :, None, None] == 0, 10000.0, attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
hidden_states=hidden_states,
attention_mask=-tf.math.abs(attention_mask),
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer encoder. Modified From huggingface/transformers."""
# pylint: disable=g-classes-have-attributes
from typing import Any, Callable, List, Optional, Union
from absl import logging
import tensorflow as tf
from official.modeling.tf_utils import get_shape_list
from official.nlp.modeling import layers
from official.projects.longformer.longformer_encoder_block import LongformerEncoderBlock
_Initializer = Union[str, tf.keras.initializers.Initializer]
_approx_gelu = lambda x: tf.keras.activations.gelu(x, approximate=True)
class LongformerEncoder(tf.keras.layers.Layer):
"""LongformerEncoder.
Args:
vocab_size: The size of the token vocabulary.
attention_window: list of ints representing the window size for each layer.
global_attention_size: the size of global attention used for each token.
pad_token_id: the token id for the pad token
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can
consume. If None, max_sequence_length uses the value from sequence length.
This determines the variable shape for positional embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network for each transformer.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network for each transformer.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: The dropout rate to use for the attention layers within
the transformer layers.
initializer: The initialzer to use for all weights in this encoder.
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which yields the full
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
matrices in the shape of ['vocab_size', 'embedding_width'] and
['embedding_width', 'hidden_size'] ('embedding_width' is usually much
smaller than 'hidden_size').
embedding_layer: An optional Layer instance which will be called to generate
embeddings for the input word IDs.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
"""
def __init__(
self,
vocab_size: int,
attention_window: Union[List[int], int] = 512,
global_attention_size: int = 0,
pad_token_id: int = 1,
hidden_size: int = 768,
num_layers: int = 12,
num_attention_heads: int = 12,
max_sequence_length: int = 512,
type_vocab_size: int = 16,
inner_dim: int = 3072,
inner_activation: Callable[..., Any] = _approx_gelu,
output_dropout: float = 0.1,
attention_dropout: float = 0.1,
initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
stddev=0.02),
output_range: Optional[int] = None,
embedding_width: Optional[int] = None,
embedding_layer: Optional[tf.keras.layers.Layer] = None,
norm_first: bool = False,
**kwargs):
super().__init__(**kwargs)
# Longformer args
self._attention_window = attention_window
self._global_attention_size = global_attention_size
self._pad_token_id = pad_token_id
activation = tf.keras.activations.get(inner_activation)
initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
use_one_hot=True,
name='type_embeddings')
self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
self._embedding_dropout = tf.keras.layers.Dropout(
rate=output_dropout, name='embedding_dropout')
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
name='embedding_projection')
self._transformer_layers = []
self._attention_mask_layer = layers.SelfAttentionMask(
name='self_attention_mask')
for i in range(num_layers):
layer = LongformerEncoderBlock(
global_attention_size=global_attention_size,
num_attention_heads=num_attention_heads,
inner_dim=inner_dim,
inner_activation=inner_activation,
attention_window=attention_window[i],
layer_id=i,
output_dropout=output_dropout,
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
name=f'transformer/layer_{i}')
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
name='pooler_transform')
self._config = {
'vocab_size': vocab_size,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length,
'type_vocab_size': type_vocab_size,
'inner_dim': inner_dim,
'inner_activation': tf.keras.activations.serialize(activation),
'output_dropout': output_dropout,
'attention_dropout': attention_dropout,
'initializer': tf.keras.initializers.serialize(initializer),
'output_range': output_range,
'embedding_width': embedding_width,
'embedding_layer': embedding_layer,
'norm_first': norm_first,
'attention_window': attention_window,
'global_attention_size': global_attention_size,
'pad_token_id': pad_token_id,
}
self.inputs = dict(
input_word_ids=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(None,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(None,), dtype=tf.int32))
def call(self, inputs):
word_embeddings = None
if isinstance(inputs, dict):
word_ids = inputs.get('input_word_ids') # input_ids
mask = inputs.get('input_mask') # attention_mask
type_ids = inputs.get('input_type_ids') # token_type_ids
word_embeddings = inputs.get('input_word_embeddings',
None) # input_embeds
else:
raise ValueError(f'Unexpected inputs type to {self.__class__}.')
(
padding_len,
word_ids,
mask,
type_ids,
word_embeddings,
) = self._pad_to_window_size(
word_ids=word_ids,
mask=mask,
type_ids=type_ids,
word_embeddings=word_embeddings,
pad_token_id=self._pad_token_id)
if word_embeddings is None:
word_embeddings = self._embedding_layer(word_ids)
# absolute position embeddings.
position_embeddings = self._position_embedding_layer(word_embeddings)
type_embeddings = self._type_embedding_layer(type_ids)
embeddings = word_embeddings + position_embeddings + type_embeddings
embeddings = self._embedding_norm_layer(embeddings)
embeddings = self._embedding_dropout(embeddings)
if self._embedding_projection is not None:
embeddings = self._embedding_projection(embeddings)
batch_size, seq_len = get_shape_list(mask)
# create masks with fixed len global_attention_size
mask = tf.transpose(
tf.concat(
values=[
tf.ones(
(self._global_attention_size, batch_size), tf.int32) * 2,
tf.transpose(mask)[self._global_attention_size:]
],
axis=0))
is_index_masked = tf.math.less(mask, 1)
is_index_global_attn = tf.transpose(
tf.concat(
values=[
tf.ones((self._global_attention_size, batch_size), tf.bool),
tf.zeros((seq_len - self._global_attention_size, batch_size),
tf.bool)
],
axis=0))
# Longformer
attention_mask = mask
extended_attention_mask = tf.reshape(
attention_mask, (tf.shape(mask)[0], tf.shape(mask)[1], 1, 1))
attention_mask = tf.cast(
tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0
encoder_outputs = []
x = embeddings
# TFLongformerEncoder
for layer in self._transformer_layers:
x = layer([x, attention_mask, is_index_masked, is_index_global_attn])
encoder_outputs.append(x)
last_encoder_output = encoder_outputs[-1]
if padding_len > 0:
last_encoder_output = last_encoder_output[:, :-padding_len]
first_token_tensor = last_encoder_output[:, 0, :]
pooled_output = self._pooler_layer(first_token_tensor)
return dict(
sequence_output=last_encoder_output,
pooled_output=pooled_output,
encoder_outputs=encoder_outputs)
def get_embedding_table(self):
return self._embedding_layer.embeddings
def get_embedding_layer(self):
return self._embedding_layer
def get_config(self):
return dict(self._config)
@property
def transformer_layers(self):
"""List of Transformer layers in the encoder."""
return self._transformer_layers
@property
def pooler_layer(self):
"""The pooler dense layer after the transformer layers."""
return self._pooler_layer
@classmethod
def from_config(cls, config, custom_objects=None):
if 'embedding_layer' in config and config['embedding_layer'] is not None:
warn_string = (
'You are reloading a model that was saved with a '
'potentially-shared embedding layer object. If you contine to '
'train this model, the embedding layer will no longer be shared. '
'To work around this, load the model outside of the Keras API.')
print('WARNING: ' + warn_string)
logging.warn(warn_string)
return cls(**config)
def _pad_to_window_size(
self,
word_ids,
mask,
type_ids,
word_embeddings,
pad_token_id,
):
# padding
attention_window = max(self._attention_window)
assert (attention_window %
2 == 0), ('`attention_window` should be an even value.'
f'Given {attention_window}')
input_shape = get_shape_list(
word_ids) if word_ids is not None else get_shape_list(word_embeddings)
batch_size, seq_len = input_shape[:2]
if seq_len is not None:
padding_len = (attention_window -
seq_len % attention_window) % attention_window
else:
padding_len = 0
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
if word_ids is not None:
word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)
if word_embeddings is not None:
def pad_embeddings():
word_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
word_embeddings_padding = self._embedding_layer(word_ids_padding)
return tf.concat([word_embeddings, word_embeddings_padding], axis=-2)
word_embeddings = tf.cond(
tf.math.greater(padding_len, 0), pad_embeddings,
lambda: word_embeddings)
mask = tf.pad(
mask, paddings,
constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(
type_ids, paddings, constant_values=0) # pad with token_type_id = 0
return (
padding_len,
word_ids,
mask,
token_type_ids,
word_embeddings,
)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer attention layer. Modified From huggingface/transformers."""
import tensorflow as tf
from official.projects.longformer.longformer_attention import LongformerAttention
@tf.keras.utils.register_keras_serializable(package="Text")
class LongformerEncoderBlock(tf.keras.layers.Layer):
"""LongformerEncoderBlock.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
"""
def __init__(
self,
global_attention_size,
num_attention_heads,
inner_dim,
inner_activation,
# Longformer
attention_window,
layer_id=0,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
output_dropout=0.0,
attention_dropout=0.0,
inner_dropout=0.0,
attention_initializer=None,
attention_axes=None,
**kwargs):
super().__init__(**kwargs)
self.global_attention_size = global_attention_size
self._num_heads = num_attention_heads
self._inner_dim = inner_dim
self._inner_activation = inner_activation
# Longformer
self._attention_window = attention_window
self._layer_id = layer_id
self._attention_dropout = attention_dropout
self._attention_dropout_rate = attention_dropout
self._output_dropout = output_dropout
self._output_dropout_rate = output_dropout
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._inner_dropout = inner_dropout
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_axes = attention_axes
def build(self, input_shape):
if isinstance(input_shape, tf.TensorShape):
input_tensor_shape = input_shape
elif isinstance(input_shape, (list, tuple)):
input_tensor_shape = tf.TensorShape(input_shape[0])
else:
raise ValueError(
f"The type of input shape argument is not supported, got: "
f"{type(input_shape)}")
einsum_equation = "abc,cd->abd"
if len(input_tensor_shape.as_list()) > 3:
einsum_equation = "...bc,cd->...bd"
hidden_size = input_tensor_shape[-1]
if hidden_size % self._num_heads != 0:
raise ValueError(
f"The input size ({hidden_size}) is not a multiple of the number of attention "
f"heads ({self._num_heads})")
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# TFLongformerSelfAttention + TFLongformerSelfOutput.dense
self._attention_layer = LongformerAttention(
# Longformer
layer_id=self._layer_id,
global_attention_size=self.global_attention_size,
attention_window=self._attention_window,
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
attention_axes=self._attention_axes,
name="self_attention",
**common_kwargs)
# TFLongformerSelfOutput.dropout
self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
# TFLongformerSelfOutput.Layernorm
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
# TFLongformerIntermediate
# TFLongformerIntermediate.dense
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy = tf.float32
# TFLongformerIntermediate.intermediate_act_fn
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
# TFLongformerOutput
# TFLongformerOutput.dense
self._output_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
**common_kwargs)
# TFLongformerOutput.dropout
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
# TFLongformerOutput.layernorm
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super().build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"output_dropout":
self._output_dropout_rate,
"attention_dropout":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"inner_dropout":
self._inner_dropout,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes":
self._attention_axes,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
"""Transformer self-attention encoder block call.
Args:
inputs: a single tensor or a list of tensors. `input tensor` as the single
sequence of embeddings. [`input tensor`, `attention mask`] to have the
additional attention mask. [`query tensor`, `key value tensor`,
`attention mask`] to have separate input streams for the query, and
key/value to the multi-head attention.
Returns:
An output tensor with the same dimensions as input/query tensor.
"""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 4:
(
input_tensor,
attention_mask,
is_index_masked,
is_index_global_attn,
) = inputs
key_value = None
elif len(inputs) == 5:
assert False # No key_value
else:
raise ValueError(
f"Unexpected inputs to {self.__class__} with length at {len(inputs)}"
)
else:
input_tensor = inputs
attention_mask = None
is_index_masked = None
is_index_global_attn = None
key_value = None
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
if is_index_masked is not None:
is_index_masked = is_index_masked[:, 0:self._output_range]
if is_index_global_attn is not None:
is_index_global_attn = is_index_global_attn[:, 0:self._output_range]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
hidden_states=target_tensor,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
)
# TFLongformerAttention.TFLongformerSelfOutput.* - {.dense}
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
# TFLongformerIntermediate
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
# TFLongformerOutput
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + layer_output
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(layer_output + attention_output)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.projects.longformer.longformer_encoder."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.projects.longformer.longformer_encoder import LongformerEncoder
class LongformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super(LongformerEncoderTest, self).setUp()
np.random.seed(0)
tf.random.set_seed(0)
@combinations.generate(
combinations.combine(
attention_window=[32, 128], global_attention_size=[0, 1, 2]))
def test_encoder(self, attention_window, global_attention_size):
sequence_length = 128
batch_size = 2
vocab_size = 1024
hidden_size = 256
network = LongformerEncoder(
global_attention_size=global_attention_size,
vocab_size=vocab_size,
attention_window=[attention_window],
hidden_size=hidden_size,
num_layers=1,
num_attention_heads=4,
max_sequence_length=512)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length), dtype=np.int32)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length), dtype=np.int32)
type_id_data = np.random.randint(
2, size=(batch_size, sequence_length), dtype=np.int32)
inputs = {
'input_word_ids': word_id_data,
'input_mask': mask_data,
'input_type_ids': type_id_data,
}
outputs = network(inputs)
self.assertEqual(outputs['sequence_output'].shape,
(batch_size, sequence_length, hidden_size))
@combinations.generate(
combinations.combine(
norm_first=[True, False], global_attention_size=[0, 1, 2]))
def test_norm_first(self, norm_first, global_attention_size):
sequence_length = 128
batch_size = 2
vocab_size = 1024
hidden_size = 256
network = LongformerEncoder(
global_attention_size=global_attention_size,
vocab_size=vocab_size,
attention_window=[32],
hidden_size=hidden_size,
num_layers=1,
num_attention_heads=4,
max_sequence_length=512,
norm_first=norm_first)
word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length), dtype=np.int32)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length), dtype=np.int32)
type_id_data = np.random.randint(
2, size=(batch_size, sequence_length), dtype=np.int32)
inputs = {
'input_word_ids': word_id_data,
'input_mask': mask_data,
'input_type_ids': type_id_data,
}
outputs = network(inputs)
self.assertEqual(outputs['sequence_output'].shape,
(batch_size, sequence_length, hidden_size))
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Longformer experiments."""
# pylint: disable=g-doc-return-or-yield,line-too-long
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import masked_lm
from official.nlp.tasks import sentence_prediction
from official.projects.longformer.longformer import LongformerEncoderConfig
AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig
@dataclasses.dataclass
class LongformerOptimizationConfig(optimization.OptimizationConfig):
"""Longformer optimization configuration."""
optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
type='adamw',
adamw=AdamWeightDecay(
weight_decay_rate=0.01,
exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
epsilon=1e-6))
learning_rate: optimization.LrConfig = optimization.LrConfig(
type='polynomial',
polynomial=PolynomialLr(
initial_learning_rate=1e-4,
decay_steps=1000000,
end_learning_rate=0.0))
warmup: optimization.WarmupConfig = optimization.WarmupConfig(
type='polynomial', polynomial=PolynomialWarmupConfig(warmup_steps=10000))
@exp_factory.register_config_factory('longformer/pretraining')
def longformer_pretraining() -> cfg.ExperimentConfig:
"""Longformer pretraining experiment."""
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=masked_lm.MaskedLMConfig(
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
type='any', any=LongformerEncoderConfig()),
cls_heads=[
bert.ClsHeadConfig(
inner_dim=768,
num_classes=2,
dropout_rate=0.1,
name='next_sentence')
]),
train_data=pretrain_dataloader.BertPretrainDataConfig(
use_v2_feature_names=True),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
use_v2_feature_names=True, is_training=False)),
trainer=cfg.TrainerConfig(
optimizer_config=LongformerOptimizationConfig(), train_steps=1000000),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('longformer/glue')
def longformer_glue() -> cfg.ExperimentConfig:
"""Longformer glue fine-tuning."""
config = cfg.ExperimentConfig(
task=sentence_prediction.SentencePredictionConfig(
model=sentence_prediction.ModelConfig(
encoder=encoders.EncoderConfig(
type='any', any=LongformerEncoderConfig())),
train_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(),
validation_data=sentence_prediction_dataloader
.SentencePredictionDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 3e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
...@@ -12,22 +12,19 @@ ...@@ -12,22 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3 """A customized training library for the specific task."""
"""TensorFlow Model Garden Vision training driver."""
from absl import app from absl import app
from absl import flags from absl import flags
import gin import gin
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils from official.common import distribute_utils
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils from official.core import train_utils
from official.modeling import performance from official.modeling import performance
from official.projects.longformer import longformer_experiments # pylint: disable=unused-import
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -51,7 +48,9 @@ def main(_): ...@@ -51,7 +48,9 @@ def main(_):
distribution_strategy=params.runtime.distribution_strategy, distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg, all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus, num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu) tpu_address=params.runtime.tpu,
**params.runtime.model_parallelism())
with distribution_strategy.scope(): with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir) task = task_factory.get_task(params.task, logging_dir=model_dir)
...@@ -64,7 +63,7 @@ def main(_): ...@@ -64,7 +63,7 @@ def main(_):
train_utils.save_gin_config(FLAGS.mode, model_dir) train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main) app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts pre-trained pytorch checkpoint into a tf encoder checkpoint."""
import os
from absl import app
import numpy as np
import tensorflow as tf
import transformers
from official.modeling import tf_utils
from official.projects.longformer.longformer import LongformerEncoderConfig
from official.projects.longformer.longformer_encoder import LongformerEncoder
def _get_pytorch_longformer_model():
pretrained_lm = "allenai/longformer-base-4096"
model = transformers.AutoModel.from_pretrained(pretrained_lm)
return {n: p.data.numpy() for n, p in model.named_parameters()}
def _create_longformer_model():
"""Creates a Longformer model."""
encoder_cfg = LongformerEncoderConfig
encoder_cfg.vocab_size = 50265
encoder_cfg.max_position_embeddings = 4098
encoder_cfg.attention_window = [2] * encoder_cfg.num_layers
encoder_cfg.global_attention_size = 1
encoder = LongformerEncoder(
attention_window=encoder_cfg.attention_window,
global_attention_size=encoder_cfg.global_attention_size,
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.intermediate_size,
inner_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
output_dropout=encoder_cfg.dropout_rate,
attention_dropout=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_size,
norm_first=encoder_cfg.norm_first)
return encoder
# pylint: disable=protected-access
def convert(encoder, allenai_model):
"""Convert AllenAI Longformer to the one in the codebase."""
num_layers = encoder._config["num_layers"]
num_attention_heads = encoder._config["num_attention_heads"]
hidden_size = encoder._config["hidden_size"]
head_size = hidden_size // num_attention_heads
assert head_size * num_attention_heads == hidden_size
encoder._embedding_layer.set_weights(
[allenai_model["embeddings.word_embeddings.weight"]])
encoder._embedding_norm_layer.set_weights([
allenai_model["embeddings.LayerNorm.weight"],
allenai_model["embeddings.LayerNorm.bias"]
])
encoder._type_embedding_layer.set_weights([
np.repeat(
allenai_model["embeddings.token_type_embeddings.weight"], 2, axis=0)
])
encoder._position_embedding_layer.set_weights(
[allenai_model["embeddings.position_embeddings.weight"]])
encoder._pooler_layer.set_weights([
allenai_model["pooler.dense.weight"], allenai_model["pooler.dense.bias"]
])
for layer_num in range(num_layers):
encoder._transformer_layers[
layer_num]._attention_layer._global_key_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.self.key_global.weight"].T
.reshape(
(hidden_size, num_attention_heads, head_size)), allenai_model[
f"encoder.layer.{layer_num}.attention.self.key_global.bias"]
.reshape((num_attention_heads, head_size))
])
encoder._transformer_layers[
layer_num]._attention_layer._global_query_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.self.query_global.weight"]
.T.reshape((hidden_size, num_attention_heads, head_size)),
allenai_model[
f"encoder.layer.{layer_num}.attention.self.query_global.bias"]
.reshape((num_attention_heads, head_size))
])
encoder._transformer_layers[
layer_num]._attention_layer._global_value_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.self.value_global.weight"]
.T.reshape((hidden_size, num_attention_heads, head_size)),
allenai_model[
f"encoder.layer.{layer_num}.attention.self.value_global.bias"]
.reshape((num_attention_heads, head_size))
])
encoder._transformer_layers[
layer_num]._attention_layer._key_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.self.key.weight"].T
.reshape(
(hidden_size, num_attention_heads, head_size)), allenai_model[
f"encoder.layer.{layer_num}.attention.self.key_global.bias"]
.reshape((num_attention_heads, head_size))
])
encoder._transformer_layers[
layer_num]._attention_layer._query_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.self.query.weight"].T
.reshape((hidden_size, num_attention_heads, head_size)),
allenai_model[
f"encoder.layer.{layer_num}.attention.self.query.bias"].reshape(
(num_attention_heads, head_size))
])
encoder._transformer_layers[
layer_num]._attention_layer._value_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.self.value.weight"].T
.reshape((hidden_size, num_attention_heads, head_size)),
allenai_model[
f"encoder.layer.{layer_num}.attention.self.value.bias"].reshape(
(num_attention_heads, head_size))
])
encoder._transformer_layers[
layer_num]._attention_layer._output_dense.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.output.dense.weight"].T,
allenai_model[
f"encoder.layer.{layer_num}.attention.output.dense.bias"]
])
encoder._transformer_layers[layer_num]._attention_layer_norm.set_weights([
allenai_model[
f"encoder.layer.{layer_num}.attention.output.LayerNorm.weight"],
allenai_model[
f"encoder.layer.{layer_num}.attention.output.LayerNorm.bias"]
])
encoder._transformer_layers[layer_num]._intermediate_dense.set_weights([
allenai_model[f"encoder.layer.{layer_num}.intermediate.dense.weight"].T,
allenai_model[f"encoder.layer.{layer_num}.intermediate.dense.bias"]
])
encoder._transformer_layers[layer_num]._output_dense.set_weights([
allenai_model[f"encoder.layer.{layer_num}.output.dense.weight"].T,
allenai_model[f"encoder.layer.{layer_num}.output.dense.bias"]
])
encoder._transformer_layers[layer_num]._output_layer_norm.set_weights([
allenai_model[f"encoder.layer.{layer_num}.output.LayerNorm.weight"],
allenai_model[f"encoder.layer.{layer_num}.output.LayerNorm.bias"]
])
def convert_checkpoint(output_path):
"""Converts and save the checkpoint."""
output_dir, _ = os.path.split(output_path)
tf.io.gfile.makedirs(output_dir)
encoder = _create_longformer_model()
allenai_model = _get_pytorch_longformer_model()
sequence_length = 128
batch_size = 2
word_id_data = np.random.randint(
10, size=(batch_size, sequence_length), dtype=np.int32)
mask_data = np.random.randint(
2, size=(batch_size, sequence_length), dtype=np.int32)
type_id_data = np.random.randint(
2, size=(batch_size, sequence_length), dtype=np.int32)
inputs = {
"input_word_ids": word_id_data,
"input_mask": mask_data,
"input_type_ids": type_id_data,
}
encoder(inputs)
convert(encoder, allenai_model)
tf.train.Checkpoint(encoder=encoder).write(output_path)
def main(_):
convert_checkpoint("longformer-4096/longformer")
if __name__ == "__main__":
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Longformer training examples to Tfrecord."""
import collections
import os
import datasets
import tensorflow as tf
import transformers
pretrained_lm = "allenai/longformer-base-4096"
task_name = "mnli"
save_path = "./"
raw_datasets = datasets.load_dataset("glue", task_name, cache_dir=None)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_lm,
use_fast=True,
)
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
sentence1_key, sentence2_key = task_to_keys[task_name]
padding = "max_length"
# make sure this is the same with model input size.
max_seq_length = 512
def preprocess_function(examples):
# Tokenize the texts
args = ((examples[sentence1_key],) if sentence2_key is None else
(examples[sentence1_key], examples[sentence2_key]))
result = tokenizer(
*args, padding=padding, max_length=max_seq_length, truncation=True)
return result
raw_datasets = raw_datasets.map(
preprocess_function,
batched=True,
desc="Running tokenizer on dataset",
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["validation_matched" if task_name ==
"mnli" else "validation"]
print("train_dataset", train_dataset[0])
print("eval_dataset", eval_dataset[0])
def file_based_convert_examples_to_features(examples, output_file):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
for ex_index, example in enumerate(examples):
if ex_index % 10000 == 0:
print(f"Writing example {ex_index} of {len(examples)}")
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(example["input_ids"])
features["input_mask"] = create_int_feature(example["attention_mask"])
features["segment_ids"] = create_int_feature([0] *
len(example["attention_mask"]))
features["label_ids"] = create_int_feature([example["label"]])
features["is_real_example"] = create_int_feature([1])
features["example_id"] = create_int_feature([example["idx"]])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
writer.close()
file_based_convert_examples_to_features(
train_dataset,
os.path.join(save_path,
f"{pretrained_lm.replace('/', '_')}_train.tf_record"))
file_based_convert_examples_to_features(
eval_dataset,
os.path.join(save_path,
f"{pretrained_lm.replace('/', '_')}_eval.tf_record"))
...@@ -176,8 +176,7 @@ devices. See the [TF Lite Example](#tf-lite-example) to export and run your own ...@@ -176,8 +176,7 @@ devices. See the [TF Lite Example](#tf-lite-example) to export and run your own
models. We also provide [quantized TF Lite binaries via TF Hub](https://tfhub.dev/s?deployment-format=lite&q=movinet). models. We also provide [quantized TF Lite binaries via TF Hub](https://tfhub.dev/s?deployment-format=lite&q=movinet).
For reference, MoViNet-A0-Stream runs with a similar latency to For reference, MoViNet-A0-Stream runs with a similar latency to
[MobileNetV3-Large] [MobileNetV3-Large](https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
(https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/)
with +5% accuracy on Kinetics 600. with +5% accuracy on Kinetics 600.
| Model Name | Input Shape | Pixel 4 Latency\* | x86 Latency\* | TF Lite Binary | | Model Name | Input Shape | Pixel 4 Latency\* | x86 Latency\* | TF Lite Binary |
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Contains definitions of Mobile Video Networks. """Contains definitions of Mobile Video Networks.
Reference: https://arxiv.org/pdf/2103.11511.pdf Reference: https://arxiv.org/pdf/2103.11511.pdf
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Contains common building blocks for MoViNets. """Contains common building blocks for MoViNets.
Reference: https://arxiv.org/pdf/2103.11511.pdf Reference: https://arxiv.org/pdf/2103.11511.pdf
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for movinet_layers.py.""" """Tests for movinet_layers.py."""
from absl.testing import parameterized from absl.testing import parameterized
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for movinet_model.py.""" """Tests for movinet_model.py."""
from absl.testing import parameterized from absl.testing import parameterized
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for movinet.py.""" """Tests for movinet.py."""
from absl.testing import parameterized from absl.testing import parameterized
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
r"""Exports models to tf.saved_model. r"""Exports models to tf.saved_model.
Export example: Export example:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment