# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TEAMS model configurations and instantiation methods."""
import dataclasses

import gin
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.modeling import networks


@dataclasses.dataclass
class TeamsPretrainerConfig(base_config.Config):
  """TEAMS pretrainer configuration."""
  # Number of tokens masked out for the generator's masked language model.
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  # Relative weight of the discriminator loss in the total pretraining loss.
  discriminator_loss_weight: float = 50.0
  # Whether to share the embedding network between the generator and
  # discriminator.
  tie_embeddings: bool = True
  # Number of bottom layers shared between generator and discriminator.
  num_shared_generator_hidden_layers: int = 3
  # Number of bottom layers shared between different discriminator tasks.
  num_discriminator_task_agnostic_layers: int = 11
  # Whether to disallow the generator from sampling the original (correct)
  # token when generating replacements.
  disallow_correct: bool = False
  generator: encoders.BertEncoderConfig = dataclasses.field(
      default_factory=encoders.BertEncoderConfig)
  discriminator: encoders.BertEncoderConfig = dataclasses.field(
      default_factory=encoders.BertEncoderConfig)

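# Example configuration (a minimal sketch; the generator and discriminator
# sizes below are illustrative assumptions, not values from this file):
#
#   pretrainer_cfg = TeamsPretrainerConfig(
#       generator=encoders.BertEncoderConfig(num_layers=4),
#       discriminator=encoders.BertEncoderConfig(num_layers=12))
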

@gin.configurable
def get_encoder(bert_config,
                encoder_scaffold_cls,
                embedding_inst=None,
                hidden_inst=None):
  """Gets an 'EncoderScaffold' object.

  Args:
    bert_config: A 'modeling.BertConfig'.
    encoder_scaffold_cls: An 'EncoderScaffold' class.
    embedding_inst: Embedding network instance. If None, a
      'PackedSequenceEmbedding' network is built from 'bert_config'.
    hidden_inst: List of hidden layer instances. If None, the scaffold's
      default hidden layer class is used.

  Returns:
    An encoder object.
  """
  # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
  embedding_cfg = dict(
      vocab_size=bert_config.vocab_size,
      type_vocab_size=bert_config.type_vocab_size,
      hidden_size=bert_config.hidden_size,
      embedding_width=bert_config.embedding_size,
      max_seq_length=bert_config.max_position_embeddings,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      dropout_rate=bert_config.hidden_dropout_prob,
  )
  if embedding_inst is None:
    # Build a default packed-sequence embedding network; 'embedding_cfg' is
    # hoisted out of the branch so it is always defined for the kwargs below.
    embedding_inst = networks.PackedSequenceEmbedding(**embedding_cfg)
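  # Per-layer Transformer settings applied to each hidden layer instantiated
  # by the encoder scaffold.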
  hidden_cfg = dict(
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      kernel_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
  )
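  # Collect the constructor arguments for the encoder scaffold.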
  kwargs = dict(
      embedding_cfg=embedding_cfg,
      embedding_cls=embedding_inst,
      hidden_cfg=hidden_cfg,
      num_hidden_instances=bert_config.num_hidden_layers,
      pooled_output_dim=bert_config.hidden_size,
      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if hidden_inst is not None:
    # Only override the scaffold's default hidden layer class when instances
    # are provided; passing hidden_cls=None would break layer construction.
    kwargs['hidden_cls'] = hidden_inst

  # Relies on gin configuration to define the Transformer encoder arguments.
  return encoder_scaffold_cls(**kwargs)
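
# Example usage (a minimal sketch; the 'configs.BertConfig' import path and
# the hyperparameter values are assumptions, not taken from this file):
#
#   from official.nlp.bert import configs
#
#   bert_config = configs.BertConfig(
#       vocab_size=30522,
#       hidden_size=256,
#       embedding_size=128,
#       num_hidden_layers=4,
#       num_attention_heads=4,
#       intermediate_size=1024)
#   encoder = get_encoder(bert_config, networks.EncoderScaffold)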