electra.py 3.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer


@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration.

  Used by `instantiate_pretrainer_from_cfg` in this module to build an
  `electra_pretrainer.ElectraPretrainer`.
  """
  # Forwarded to ElectraPretrainer as `num_token_predictions`.
  num_masked_tokens: int = 76
  # Forwarded to ElectraPretrainer as `sequence_length`.
  sequence_length: int = 512
  # Forwarded to ElectraPretrainer as `num_classes`.
  num_classes: int = 2
  # Not read in this module; presumably consumed by the pretraining task's
  # loss computation — TODO confirm.
  discriminator_loss_weight: float = 50.0
  # When True and no generator network is supplied, the generator encoder is
  # built around the discriminator's embedding layer.
  tie_embeddings: bool = True
  # Forwarded unchanged to ElectraPretrainer as `disallow_correct`.
  disallow_correct: bool = False
  # Generator encoder settings. Its `vocab_size`, `hidden_activation` and
  # `initializer_range` also configure the pretrainer's masked-LM head.
  # `default_factory` gives every config its own encoder config: a shared
  # instance default would be aliased across all configs, and dataclasses on
  # Python 3.11+ reject unhashable instance defaults outright.
  generator_encoder: encoders.TransformerEncoderConfig = dataclasses.field(
      default_factory=encoders.TransformerEncoderConfig)
  # Discriminator encoder settings (same default_factory rationale as above).
  discriminator_encoder: encoders.TransformerEncoderConfig = dataclasses.field(
      default_factory=encoders.TransformerEncoderConfig)
  # One entry per classification head to attach to the pretrainer.
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)


def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  """Builds one `layers.ClassificationHead` per head configuration.

  Args:
    cls_head_configs: Head configurations. Each one is expanded with
      `as_dict()` into keyword arguments for `layers.ClassificationHead`.
      A falsy value (an empty list, or None) produces no heads.

  Returns:
    A new list of classification heads, in the same order as the configs.
  """
  return [
      layers.ClassificationHead(**head_cfg.as_dict())
      for head_cfg in (cls_head_configs or [])
  ]


def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
    ) -> electra_pretrainer.ElectraPretrainer:
  """Builds an `ElectraPretrainer` from an `ELECTRAPretrainerConfig`.

  Args:
    config: The pretrainer configuration.
    generator_network: Optional pre-built generator encoder. When None, one is
      instantiated from `config.generator_encoder`, optionally sharing the
      discriminator's embedding layer if `config.tie_embeddings` is set.
    discriminator_network: Optional pre-built discriminator encoder. When
      None, one is instantiated from `config.discriminator_encoder`.

  Returns:
    An `electra_pretrainer.ElectraPretrainer` that combines both encoders. Its
    masked-LM vocabulary size, activation and initializer are taken from
    `config.generator_encoder`, and its classification heads from
    `config.cls_heads`.
  """
  gen_cfg = config.generator_encoder
  disc_cfg = config.discriminator_encoder

  # The discriminator is resolved first so its embedding layer is available
  # when the generator is built with tied embeddings below.
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(disc_cfg)

  if generator_network is None:
    generator_kwargs = {}
    if config.tie_embeddings:
      # Reuse the discriminator's embeddings in the generator (the original
      # author's stated motive: easier model serialization).
      generator_kwargs['embedding_layer'] = (
          discriminator_network.get_embedding_layer())
    generator_network = encoders.instantiate_encoder_from_cfg(
        gen_cfg, **generator_kwargs)

  # The masked-LM head is configured from the generator encoder's settings.
  mlm_activation = tf_utils.get_activation(gen_cfg.hidden_activation)
  mlm_initializer = tf.keras.initializers.TruncatedNormal(
      stddev=gen_cfg.initializer_range)
  classification_heads = instantiate_classification_heads_from_cfgs(
      config.cls_heads)

  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=gen_cfg.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=mlm_activation,
      mlm_initializer=mlm_initializer,
      classification_heads=classification_heads,
      disallow_correct=config.disallow_correct)