"...en/_static/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "91f8cbe25a0896168a5a9ae2245144b02d588083"
Commit 608a8f5b authored by thomwolf's avatar thomwolf
Browse files

updating tf 2.0 layer_norm to T5 layer norm

parent 8e651f56
......@@ -17,16 +17,11 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import math
import os
import sys
import copy
import itertools
from io import open
import numpy as np
import tensorflow as tf
from .configuration_t5 import T5Config
......@@ -45,6 +40,28 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
# - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
####################################################
class TFT5LayerNorm(tf.keras.layers.Layer):
def __init__(self, epsilon=1e-6, **kwargs):
""" Construct a layernorm module in the T5 style
No bias and no substraction of mean.
"""
super(TFT5LayerNorm, self).__init__(**kwargs)
self.variance_epsilon = epsilon
def build(self, input_shape):
"""Build shared word embedding layer """
self.weight = self.add_weight(
"weight",
shape=(input_shape[-1],),
initializer='ones')
super(TFT5LayerNorm, self).build(input_shape)
def call(self, x):
variance = tf.math.reduce_min(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * x
class TFT5DenseReluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFT5DenseReluDense, self).__init__(**kwargs)
......@@ -65,8 +82,8 @@ class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFT5LayerFF, self).__init__(**kwargs)
self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, training=False):
......@@ -249,8 +266,8 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
self.SelfAttention = TFT5Attention(config,
has_relative_attention_bias=has_relative_attention_bias,
name='SelfAttention')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, attention_mask=None, position_bias=None,
......@@ -273,8 +290,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
self.EncDecAttention = TFT5Attention(config,
has_relative_attention_bias=has_relative_attention_bias,
name='EncDecAttention')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, kv, attention_mask=None, position_bias=None,
......@@ -353,8 +370,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
has_relative_attention_bias=bool(i == 0),
name='block_._{}'.format(i))
for i in range(config.num_layers)]
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
name='final_layer_norm')
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon,
name='final_layer_norm')
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def _resize_token_embeddings(self, new_num_tokens):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment