# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model in TF 2.0.

Model paper: https://arxiv.org/pdf/1706.03762.pdf
Transformer model code source: https://github.com/tensorflow/tensor2tensor
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.transformer import metrics

# Disable the not-callable lint error, since it claims many objects are not
# callable when they actually are.
# pylint: disable=not-callable

def create_model(params, is_train):
  """Creates a Keras transformer model for training or inference.

  Args:
    params: Dict of hyperparameters. Keys read here: "vocab_size",
      "label_smoothing", "enable_metrics_in_training"; the full dict is also
      forwarded to `seq2seq_transformer.Seq2SeqTransformer`.
    is_train: Bool. If True, builds a training model that takes (inputs,
      targets), adds the transformer loss via `model.add_loss`, and outputs
      logits. If False, builds an inference model that takes only inputs and
      outputs (decoded token ids, beam scores).

  Returns:
    A `tf.keras.Model`. In training mode its inputs are
    `[inputs, targets]` and its output is the logits tensor; in inference
    mode its input is `inputs` and its outputs are `[outputs, scores]`.
  """
  with tf.name_scope("model"):
    if is_train:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
      internal_model = seq2seq_transformer.Seq2SeqTransformer(
          params, name="transformer_v2")
      logits = internal_model([inputs, targets], training=is_train)
      vocab_size = params["vocab_size"]
      label_smoothing = params["label_smoothing"]
      if params["enable_metrics_in_training"]:
        # MetricLayer passes logits through unchanged while registering
        # training metrics (accuracy, etc.) on the model.
        logits = metrics.MetricLayer(vocab_size)([logits, targets])
      # Identity Lambda pins the output name to "logits" and casts to
      # float32 so the loss is computed in full precision even under
      # mixed-precision policies.
      logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
                                      dtype=tf.float32)(logits)
      model = tf.keras.Model([inputs, targets], logits)
      # Loss is attached via add_loss (not compile(loss=...)) because it
      # depends on both a model input (targets) and the output logits.
      loss = metrics.transformer_loss(
          logits, targets, label_smoothing, vocab_size)
      model.add_loss(loss)
      return model

    else:
      inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
      internal_model = seq2seq_transformer.Seq2SeqTransformer(
          params, name="transformer_v2")
      # In inference mode the internal model runs autoregressive decoding
      # and returns a dict with decoded ids and their beam-search scores.
      ret = internal_model([inputs], training=is_train)
      outputs, scores = ret["outputs"], ret["scores"]
      return tf.keras.Model(inputs, [outputs, scores])