"configs/datasets/vscode:/vscode.git/clone" did not exist on "fe0b71703316d82c14888c6d4f81d8db5dc4b225"
Commit fafd4c86 authored by thomwolf

fix TF 2.0 version of T5 - update conversion script

parent 67a8be8e
......@@ -120,24 +120,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
if compare_with_pt_model:
inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
tf_inputs = tf_model.dummy_inputs
tfo = tf_model(tf_inputs, training=False) # build the network
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
config=config,
state_dict=state_dict)
pt_inputs = torch.tensor(inputs_list)
with torch.no_grad():
pto = pt_model(pt_inputs)
pto = pt_model(**pt_model.dummy_inputs)
np_pt = pto[0].detach().numpy()
np_pt = pto[0].numpy()
np_tf = tfo[0].numpy()
diff = np.amax(np.abs(np_pt - np_tf))
print("Max absolute difference between models outputs {}".format(diff))
assert diff <= 2e-2, "Error, model absolute difference is >2e-2"
assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)
# Save pytorch-model
print("Save TensorFlow model to {}".format(tf_dump_path))
......
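The hunk above switches the sanity check to each framework's own `dummy_inputs` and reports the offending value when the tolerance check fails. A minimal sketch of that comparison step, assuming `tf_model` and `pt_model` are already-loaded TF 2.0 and PyTorch instances of the same checkpoint (the helper name and `tolerance` argument are illustrative, not part of the commit):

```python
import numpy as np
import torch

def compare_pt_tf_outputs(pt_model, tf_model, tolerance=2e-2):
    # Build the TF network by running it once on its own dummy inputs
    tfo = tf_model(tf_model.dummy_inputs, training=False)
    # Run the PyTorch model on its dummy inputs without tracking gradients
    with torch.no_grad():
        pto = pt_model(**pt_model.dummy_inputs)
    np_pt = pto[0].numpy()
    np_tf = tfo[0].numpy()
    diff = np.amax(np.abs(np_pt - np_tf))
    print("Max absolute difference between models outputs {}".format(diff))
    assert diff <= tolerance, "Error, model absolute difference is >{}: {}".format(tolerance, diff)
    return diff
```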
......@@ -73,6 +73,9 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME = 'model.ckpt'
CONFIG_NAME = "config.json"
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
def is_torch_available():
return _torch_available
......
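With the two constants now defined in `file_utils.py`, both frameworks build their dummy batches from the same nested lists instead of duplicating the literals. A short sketch of the intended usage (variable names here are illustrative):

```python
import tensorflow as tf
import torch

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

# The same token ids and mask end up on both sides of the PT/TF comparison.
pt_ids, pt_mask = torch.tensor(DUMMY_INPUTS), torch.tensor(DUMMY_MASK)   # shape (3, 5)
tf_ids, tf_mask = tf.constant(DUMMY_INPUTS), tf.constant(DUMMY_MASK)     # shape (3, 5)
print(pt_ids.shape, tf_ids.shape)
```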
......@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .modeling_utils import PreTrainedModel
from .configuration_t5 import T5Config
from .file_utils import add_start_docstrings
from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
logger = logging.getLogger(__name__)
......@@ -451,6 +451,15 @@ class T5PreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer"
@property
def dummy_inputs(self):
input_ids = torch.tensor(DUMMY_INPUTS)
input_mask = torch.tensor(DUMMY_MASK)
dummy_inputs = {'decoder_input_ids': input_ids,
'encoder_input_ids': input_ids,
'decoder_attention_mask': input_mask}
return dummy_inputs
def _init_weights(self, module):
""" Initialize the weights """
factor = self.config.initializer_factor # Used for testing weights initialization
......@@ -534,9 +543,10 @@ class T5Stack(T5PreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# T5 has a mask that can compare sequence ids, we simulate this here with this transposistion
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
# extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
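With the sequence-id comparison commented out, the mask stays a standard additive mask: 1.0 where positions may be attended, 0.0 where they are padded, converted to 0 / -1e9 before being added to the attention scores. A small sketch of that transformation, assuming a 2D `attention_mask` of shape `(batch, seq_len)`:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])            # 1 = attend, 0 = masked
dtype = torch.float32                                        # stands in for next(self.parameters()).dtype

extended_attention_mask = attention_mask[:, None, None, :]   # (batch, 1, 1, seq_len), broadcastable to the scores
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

# Adding this to the raw scores before the softmax leaves attended positions
# untouched and pushes masked positions towards zero probability.
print(extended_attention_mask)   # 0 for the first three positions, -1e9 for the padded ones
```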
......@@ -548,6 +558,10 @@ class T5Stack(T5PreTrainedModel):
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
......@@ -590,6 +604,7 @@ class T5Stack(T5PreTrainedModel):
hidden_states = layer_outputs[0]
if i == 0:
# We share the position biases between the layers - the first layer stores them

# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2 if self.output_attentions else 1]
if self.is_decoder:
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
......
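The new comment documents the tuple layout that the index arithmetic relies on: only the first layer computes the relative position biases, and later layers reuse the tensors extracted here. A sketch that restates that bookkeeping as a helper (the function name is illustrative, not part of the commit):

```python
def extract_position_biases(layer_outputs, output_attentions, is_decoder):
    # layer_outputs = hidden-states, (self-attention weights), self-attention position bias,
    #                 (cross-attention weights), cross-attention position bias
    # The entries in parentheses are only present when output_attentions is True,
    # which is why the indices shift depending on that flag.
    position_bias = layer_outputs[2 if output_attentions else 1]
    encoder_decoder_position_bias = None
    if is_decoder:
        encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 2]
    return position_bias, encoder_decoder_position_bias
```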
......@@ -26,7 +26,7 @@ import tensorflow as tf
from .configuration_t5 import T5Config
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .file_utils import add_start_docstrings
from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
logger = logging.getLogger(__name__)
......@@ -61,7 +61,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
super(TFT5LayerNorm, self).build(input_shape)
def call(self, x):
variance = tf.math.reduce_min(tf.math.square(x), axis=-1, keepdims=True)
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * x
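The `reduce_min` → `reduce_mean` change matters because T5's layer norm is a root-mean-square style norm: the code above squares the activations, averages them over the last axis, and rescales, with no mean subtraction and no bias term. A PyTorch rendering of the same computation for comparison (the `eps` value is an assumption; the diff does not show `variance_epsilon`):

```python
import torch

def t5_layer_norm(x, weight, eps=1e-6):
    # variance is the mean of the squared activations (RMS norm), mirroring
    # tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))

x = torch.randn(2, 5, 8)
weight = torch.ones(8)
print(t5_layer_norm(x, weight).shape)   # torch.Size([2, 5, 8])
```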
......@@ -231,19 +231,19 @@ class TFT5Attention(tf.keras.layers.Layer):
cache[self.layer_id] = (k, v)
# q = q / math.sqrt(dim_per_head) # No scaling in T5
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
# scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
scores = tf.einsum('bnqd,bnkd->bnqk', q, k) # (bs, n_heads, qlen, klen)
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(qlen, klen)
scores += position_bias
if mask is not None:
scores += mask
position_bias = position_bias + mask
# mask = (mask == 0).expand_as(scores) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
scores += position_bias
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
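Two things change in this attention hunk: the score computation moves to `tf.einsum`, and the padding mask is folded into `position_bias` before that bias is added to the scores. The einsum is numerically equivalent to the batched matmul it replaces; a quick check, with shapes matching the comments above:

```python
import tensorflow as tf

q = tf.random.normal((2, 4, 5, 8))   # (bs, n_heads, qlen, dim_per_head)
k = tf.random.normal((2, 4, 7, 8))   # (bs, n_heads, klen, dim_per_head)

scores_matmul = tf.matmul(q, k, transpose_b=True)       # (bs, n_heads, qlen, klen)
scores_einsum = tf.einsum('bnqd,bnkd->bnqk', q, k)      # (bs, n_heads, qlen, klen)

print(tf.reduce_max(tf.abs(scores_matmul - scores_einsum)).numpy())   # ~0.0 up to float tolerance
```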
......@@ -350,11 +350,11 @@ class TFT5Block(tf.keras.layers.Layer):
head_mask=head_mask,
training=training)
hidden_states = cross_attention_outputs[0]
outputs = cross_attention_outputs[1:] + outputs
outputs = outputs + cross_attention_outputs[1:]
hidden_states = self.layer[2](hidden_states, training=training)
outputs = (hidden_states,) + outputs # add attentions if we output them
return outputs
return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
####################################################
......@@ -418,7 +418,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# extended_attention_mask = tf.math.equal(extended_attention_mask,
# tf.transpose(extended_attention_mask, perm=(-1, -2)))
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder:
# If a 2D or 3D attention mask is provided for the cross-attention
......@@ -430,7 +436,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if num_dims_encoder_attention_mask == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
encoder_extended_attention_mask = None
......@@ -463,6 +474,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
training=training)
hidden_states = layer_outputs[0]
if i == 0:
# We share the position biases between the layers - the first layer stores them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2 if self.output_attentions else 1]
if self.is_decoder:
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
......@@ -502,8 +515,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
@property
def dummy_inputs(self):
input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
input_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
input_ids = tf.constant(DUMMY_INPUTS)
input_mask = tf.constant(DUMMY_MASK)
dummy_inputs = {'decoder_input_ids': input_ids,
'encoder_input_ids': input_ids,
'decoder_attention_mask': input_mask}
......
......@@ -24,13 +24,11 @@ import os
import tensorflow as tf
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
logger = logging.getLogger(__name__)
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models.
......@@ -59,7 +57,7 @@ class TFPreTrainedModel(tf.keras.Model):
Returns:
tf.Tensor with dummy inputs
"""
return tf.constant(DUMMY_INPUTS)
return {'input_ids': tf.constant(DUMMY_INPUTS)}
def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
......
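Returning a dict from `dummy_inputs` lets `tf_model(tf_model.dummy_inputs, training=False)` build the Keras variables for encoder-decoder models as well, since subclasses such as `TFT5PreTrainedModel` can override the property with extra entries (decoder ids, attention mask) without any change to the calling code. A toy sketch of the pattern (`ToyTFModel` is a stand-in, not part of the commit):

```python
import tensorflow as tf

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

class ToyTFModel(tf.keras.Model):
    def __init__(self):
        super(ToyTFModel, self).__init__()
        self.embed = tf.keras.layers.Embedding(100, 16)

    @property
    def dummy_inputs(self):
        return {'input_ids': tf.constant(DUMMY_INPUTS)}

    def call(self, inputs, training=False):
        return self.embed(inputs['input_ids'])

model = ToyTFModel()
_ = model(model.dummy_inputs, training=False)   # first call creates the weights
print(len(model.weights))                       # 1 (the embedding matrix)
```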
......@@ -31,11 +31,10 @@ from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS
logger = logging.getLogger(__name__)
try:
from torch.nn import Identity
except ImportError:
......@@ -71,6 +70,15 @@ class PreTrainedModel(nn.Module):
load_tf_weights = lambda model, config, path: None
base_model_prefix = ""
@property
def dummy_inputs(self):
""" Dummy inputs to do a forward pass in the network.
Returns:
torch.Tensor with dummy inputs
"""
return {'input_ids': torch.tensor(DUMMY_INPUTS)}
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
if not isinstance(config, PretrainedConfig):
......
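The PyTorch base class gains the same hook, so the conversion script can call `pt_model(**pt_model.dummy_inputs)` uniformly, and models like T5 can override the property with seq2seq-specific inputs. A toy sketch of the mechanism (`ToyModel` is a stand-in, not part of the commit):

```python
import torch
import torch.nn as nn

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.embed = nn.Embedding(100, 16)

    @property
    def dummy_inputs(self):
        return {'input_ids': torch.tensor(DUMMY_INPUTS)}

    def forward(self, input_ids=None):
        return (self.embed(input_ids),)

model = ToyModel()
with torch.no_grad():
    outputs = model(**model.dummy_inputs)
print(outputs[0].shape)   # torch.Size([3, 5, 16])
```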