# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model in TF 2.0.

Model paper: https://arxiv.org/pdf/1706.03762.pdf
Transformer model code source: https://github.com/tensorflow/tensor2tensor
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.activations import attention_initializer
from official.nlp.modeling import layers
from official.nlp.modeling.layers import position_embedding
from official.nlp.modeling.layers import transformer
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.modeling.ops import beam_search
from official.nlp.transformer import metrics
from official.nlp.transformer import model_utils
from official.nlp.transformer.utils.tokenizer import EOS_ID


# Disable the not-callable lint error, since it claims many objects are not
# callable when they actually are.
# pylint: disable=not-callable


def create_model(params, is_train):
  """Creates a Keras Transformer model for training or inference.

  Builds a functional `tf.keras.Model` around a
  `seq2seq_transformer.Seq2SeqTransformer` core. The training graph takes
  (inputs, targets) token-id tensors and emits per-position logits with the
  transformer loss attached via `model.add_loss`; the inference graph takes
  only the inputs and returns decoded outputs plus beam scores.

  Args:
    params: Dict of hyperparameters. Keys read here: "vocab_size",
      "label_smoothing", "enable_metrics_in_training"; the full dict is
      forwarded to `Seq2SeqTransformer`.
    is_train: Python bool. True builds the training model, False the
      inference model.

  Returns:
    A `tf.keras.Model`. In training mode it maps [inputs, targets] -> logits
    and carries the transformer loss; in inference mode it maps inputs ->
    [outputs, scores].
  """
  with tf.name_scope("model"):
    if is_train:
      # Variable-length token-id sequences; batch dim is implicit.
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
      internal_model = seq2seq_transformer.Seq2SeqTransformer(
          params, name="transformer_v2")
      logits = internal_model([inputs, targets], training=is_train)
      vocab_size = params["vocab_size"]
      label_smoothing = params["label_smoothing"]
      if params["enable_metrics_in_training"]:
        # MetricLayer passes logits through unchanged while registering
        # training metrics as a side effect.
        logits = metrics.MetricLayer(vocab_size)([logits, targets])
      # Identity Lambda names the output tensor and pins it to float32
      # (keeps the loss numerically stable under mixed precision).
      logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
                                      dtype=tf.float32)(logits)
      model = tf.keras.Model([inputs, targets], logits)
      # Loss is attached to the model rather than passed to compile().
      loss = metrics.transformer_loss(
          logits, targets, label_smoothing, vocab_size)
      model.add_loss(loss)
      return model

    else:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      internal_model = seq2seq_transformer.Seq2SeqTransformer(
          params, name="transformer_v2")
      # In inference mode the core returns a dict with decoded token ids
      # and their beam-search scores.
      ret = internal_model([inputs], training=is_train)
      outputs, scores = ret["outputs"], ret["scores"]
      return tf.keras.Model(inputs, [outputs, scores])