teams.py 3.92 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TEAMS model configurations and instantiation methods."""
import dataclasses

import gin
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
24
from official.nlp.modeling import layers
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
25
26
27
28
29
30
from official.nlp.modeling import networks


@dataclasses.dataclass
class TeamsPretrainerConfig(base_config.Config):
  """Teams pretrainer configuration.

  Holds the loss weights and layer-sharing settings of the TEAMS pretrainer,
  together with the encoder configurations of its generator and discriminator.
  """
  # Candidate size for multi-word selection task, including the correct word.
  candidate_size: int = 5
  # Weight for the generator masked language model task.
  generator_loss_weight: float = 1.0
  # Weight for the replaced token detection task.
  discriminator_rtd_loss_weight: float = 5.0
  # Weight for the multi-word selection task.
  discriminator_mws_loss_weight: float = 2.0
  # Whether to share the embedding network between generator and
  # discriminator.
  tie_embeddings: bool = True
  # Number of bottom layers shared between generator and discriminator.
  # Non-positive value implies no sharing.
  num_shared_generator_hidden_layers: int = 3
  # Number of bottom layers shared between different discriminator tasks.
  num_discriminator_task_agnostic_layers: int = 11
  # Each config instance gets its own encoder configs via `default_factory`.
  # A plain instance default would be one object shared by every
  # TeamsPretrainerConfig (so mutating it leaks across configs), and
  # Python 3.11+ dataclasses reject such unhashable defaults outright.
  generator: encoders.BertEncoderConfig = dataclasses.field(
      default_factory=encoders.BertEncoderConfig)
  discriminator: encoders.BertEncoderConfig = dataclasses.field(
      default_factory=encoders.BertEncoderConfig)


class TeamsEncoderConfig(encoders.BertEncoderConfig):
  """Encoder configuration for TEAMS.

  Declares no fields of its own and reuses every field of
  `encoders.BertEncoderConfig`. It exists as a distinct type so that
  `get_encoder` can be registered for it via `base_config.bind`.
  """


@gin.configurable
@base_config.bind(TeamsEncoderConfig)
def get_encoder(bert_config: TeamsEncoderConfig,
                embedding_network=None,
                hidden_layers=None):
  """Builds a TEAMS encoder as a `networks.EncoderScaffold`.

  Args:
    bert_config: A `TeamsEncoderConfig` supplying vocabulary sizes, layer
      sizes, activation, dropout rates and initializer range.
    embedding_network: Embedding network handed to `EncoderScaffold` as
      `embedding_cls`. Defaults to `networks.PackedSequenceEmbedding`.
    hidden_layers: Hidden layer(s) handed to `EncoderScaffold` as
      `hidden_cls`. Defaults to `layers.Transformer`.

  Returns:
    An encoder object (`networks.EncoderScaffold`) built with
    `dict_outputs=True`.
  """

  def truncated_normal():
    # Each consumer gets its own initializer object, one per use site.
    return tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  embedding_cfg = {
      'vocab_size': bert_config.vocab_size,
      'type_vocab_size': bert_config.type_vocab_size,
      'hidden_size': bert_config.hidden_size,
      'embedding_width': bert_config.embedding_size,
      'max_seq_length': bert_config.max_position_embeddings,
      'initializer': truncated_normal(),
      'dropout_rate': bert_config.dropout_rate,
  }

  hidden_cfg = {
      'num_attention_heads': bert_config.num_attention_heads,
      'intermediate_size': bert_config.intermediate_size,
      'intermediate_activation': tf_utils.get_activation(
          bert_config.hidden_activation),
      'dropout_rate': bert_config.dropout_rate,
      'attention_dropout_rate': bert_config.attention_dropout_rate,
      'kernel_initializer': truncated_normal(),
  }

  embedding_cls = (networks.PackedSequenceEmbedding
                   if embedding_network is None else embedding_network)
  hidden_cls = layers.Transformer if hidden_layers is None else hidden_layers

  # Arguments not derived from `bert_config` (`embedding_network`,
  # `hidden_layers`) may be supplied through gin, since this function is
  # decorated with `@gin.configurable`.
  return networks.EncoderScaffold(
      embedding_cfg=embedding_cfg,
      embedding_cls=embedding_cls,
      hidden_cls=hidden_cls,
      hidden_cfg=hidden_cfg,
      num_hidden_instances=bert_config.num_layers,
      pooled_output_dim=bert_config.hidden_size,
      pooler_layer_initializer=truncated_normal(),
      dict_outputs=True)