Commit 608a8f5b authored by thomwolf

updating tf 2.0 layer_norm to T5 layer norm

parent 8e651f56
@@ -17,16 +17,11 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
-import json
 import logging
 import math
-import os
-import sys
 import copy
 import itertools
-from io import open
-import numpy as np
 import tensorflow as tf
 from .configuration_t5 import T5Config
@@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
 # - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
 ####################################################
+class TFT5LayerNorm(tf.keras.layers.Layer):
+    def __init__(self, epsilon=1e-6, **kwargs):
+        """ Construct a layernorm module in the T5 style.
+            No bias and no subtraction of mean.
+        """
+        super(TFT5LayerNorm, self).__init__(**kwargs)
+        self.variance_epsilon = epsilon
+
+    def build(self, input_shape):
+        """ Build the layer norm scale weight """
+        self.weight = self.add_weight(
+            "weight",
+            shape=(input_shape[-1],),
+            initializer='ones')
+        super(TFT5LayerNorm, self).build(input_shape)
+
+    def call(self, x):
+        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
+        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+
 class TFT5DenseReluDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5DenseReluDense, self).__init__(**kwargs)
@@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5LayerFF, self).__init__(**kwargs)
         self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                              name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, training=False):
@@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         self.SelfAttention = TFT5Attention(config,
                                            has_relative_attention_bias=has_relative_attention_bias,
                                            name='SelfAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                              name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, attention_mask=None, position_bias=None,
@@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         self.EncDecAttention = TFT5Attention(config,
                                              has_relative_attention_bias=has_relative_attention_bias,
                                              name='EncDecAttention')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                              name='layer_norm')
+        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                        name='layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
@@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                                 has_relative_attention_bias=bool(i == 0),
                                 name='block_._{}'.format(i))
                       for i in range(config.num_layers)]
-        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
-                                                                    name='final_layer_norm')
+        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
+                                              name='final_layer_norm')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
 
     def _resize_token_embeddings(self, new_num_tokens):
...
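
For context on the change: the new TFT5LayerNorm is an RMS-style norm. It only rescales activations by the root mean square over the last axis, with no mean subtraction and no learned bias, matching the original T5 layer norm, whereas tf.keras.layers.LayerNormalization always centers the inputs and carries a beta term. A minimal standalone sketch of the difference (not part of this commit; the toy tensor and helper name are made up for illustration):

import tensorflow as tf

def t5_layer_norm(x, weight, epsilon=1e-6):
    # Same computation as TFT5LayerNorm.call above: scale by the inverse
    # root mean square over the last axis; no centering, no bias term.
    variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
    return weight * x * tf.math.rsqrt(variance + epsilon)

x = tf.constant([[1.0, 2.0, 3.0, 4.0]])
weight = tf.ones(x.shape[-1])          # the 'ones'-initialized scale from build()

# Stock Keras layer norm for comparison: centers the inputs and rescales them.
keras_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6)

print(t5_layer_norm(x, weight).numpy())  # approx. [[0.37, 0.73, 1.10, 1.46]] -- only rescaled
print(keras_ln(x).numpy())               # approx. [[-1.34, -0.45, 0.45, 1.34]] -- centered, then rescaled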