"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "321d70a7a9dcb8a43fbbdddbf6659fa093460cae"
Commit fafd4c86 authored by thomwolf

fix TF 2.0 version of T5 - update conversion script

parent 67a8be8e
convert_pytorch_checkpoint_to_tf2.py
@@ -120,24 +120,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
     if compare_with_pt_model:
-        inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-        tf_inputs = tf_model.dummy_inputs
-        tfo = tf_model(tf_inputs, training=False)  # build the network
+        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network
         state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
         pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
                                                   config=config,
                                                   state_dict=state_dict)
-        pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
-            pto = pt_model(pt_inputs)
+            pto = pt_model(**pt_model.dummy_inputs)
-        np_pt = pto[0].detach().numpy()
+        np_pt = pto[0].numpy()
         np_tf = tfo[0].numpy()
         diff = np.amax(np.abs(np_pt - np_tf))
         print("Max absolute difference between models outputs {}".format(diff))
-        assert diff <= 2e-2, "Error, model absolute difference is >2e-2"
+        assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)
     # Save pytorch-model
     print("Save TensorFlow model to {}".format(tf_dump_path))
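Note on the hunk above: the script now builds both networks from the shared `dummy_inputs` property (added later in this diff) instead of a hand-rolled input list, and the assertion message now reports the measured difference. The comparison itself is plain NumPy; a minimal, self-contained sketch of that step, with `np_pt` and `np_tf` as illustrative stand-ins for `pto[0].numpy()` and `tfo[0].numpy()`:

```python
# Sketch of the output comparison performed by the conversion script: take the first
# output of the PyTorch model and of the TF 2.0 model as NumPy arrays and check that
# the largest element-wise deviation stays within the 2e-2 tolerance used above.
import numpy as np

np_pt = np.array([[0.101, -0.502], [0.250, 0.749]])   # stand-in for pto[0].numpy()
np_tf = np.array([[0.100, -0.500], [0.251, 0.750]])   # stand-in for tfo[0].numpy()

diff = np.amax(np.abs(np_pt - np_tf))
print("Max absolute difference between models outputs {}".format(diff))
assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)
```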
file_utils.py
@@ -73,6 +73,9 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
 TF_WEIGHTS_NAME = 'model.ckpt'
 CONFIG_NAME = "config.json"
+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
 def is_torch_available():
     return _torch_available
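The two constants added here are plain nested lists of shape (batch=3, sequence_length=5), so either framework can wrap them directly. A small illustrative sketch (values copied from the hunk above):

```python
# DUMMY_INPUTS / DUMMY_MASK are framework-agnostic Python lists; both PyTorch and
# TensorFlow turn them into (3, 5) integer tensors.
import torch
import tensorflow as tf

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

pt_ids = torch.tensor(DUMMY_INPUTS)    # torch.Size([3, 5]), dtype=torch.int64
tf_ids = tf.constant(DUMMY_INPUTS)     # TensorShape([3, 5]), dtype=tf.int32
assert list(pt_ids.shape) == tf_ids.shape.as_list() == [3, 5]
```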
modeling_t5.py
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .modeling_utils import PreTrainedModel
 from .configuration_t5 import T5Config
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
 logger = logging.getLogger(__name__)
@@ -451,6 +451,15 @@ class T5PreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
+    @property
+    def dummy_inputs(self):
+        input_ids = torch.tensor(DUMMY_INPUTS)
+        input_mask = torch.tensor(DUMMY_MASK)
+        dummy_inputs = {'decoder_input_ids': input_ids,
+                        'encoder_input_ids': input_ids,
+                        'decoder_attention_mask': input_mask}
+        return dummy_inputs
     def _init_weights(self, module):
         """ Initialize the weights """
         factor = self.config.initializer_factor  # Used for testing weights initialization
@@ -534,9 +543,10 @@ class T5Stack(T5PreTrainedModel):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        # T5 has a mask that can compare sequence ids, we simulate this here with this transposistion
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
         # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
-        extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
+        # extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
@@ -548,6 +558,10 @@ class T5Stack(T5PreTrainedModel):
             if encoder_attention_mask.dim() == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))
             encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         else:
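Note on the two mask hunks above: the sequence-id transposition trick is commented out for now, and what remains is the usual additive mask, a 0/1 mask broadcast to (batch, 1, 1, seq_len) and rescaled so kept positions contribute 0.0 and masked positions contribute -1e9 to the raw attention scores before the softmax. A quick illustrative sketch of that arithmetic with a toy tensor:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])     # (bs, seq_len), 1 = keep, 0 = padding
extended = attention_mask[:, None, None, :]          # (bs, 1, 1, seq_len) for broadcasting over heads/queries
extended = extended.to(dtype=torch.float32)          # the real code uses the model parameters' dtype (fp16 compatibility)
extended = (1.0 - extended) * -1e9                   # keep -> 0.0, masked -> -1e9

scores = torch.zeros(1, 1, 5, 5)                     # fake raw attention scores (bs, n_heads, qlen, klen)
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])                                # ~[0.333, 0.333, 0.333, 0.0, 0.0]
```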
@@ -590,6 +604,7 @@ class T5Stack(T5PreTrainedModel):
             hidden_states = layer_outputs[0]
             if i == 0:
                 # We share the position biases between the layers - the first layer store them
+                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
modeling_tf_t5.py
@@ -26,7 +26,7 @@ import tensorflow as tf
 from .configuration_t5 import T5Config
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
 logger = logging.getLogger(__name__)
@@ -61,7 +61,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
         super(TFT5LayerNorm, self).build(input_shape)
     def call(self, x):
-        variance = tf.math.reduce_min(tf.math.square(x), axis=-1, keepdims=True)
+        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
         x = x * tf.math.rsqrt(variance + self.variance_epsilon)
         return self.weight * x
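Note on the TFT5LayerNorm hunk above: T5's layer norm is a root-mean-square norm with no mean subtraction, so the scale has to come from the mean of the squared activations; `reduce_min` was a bug. A small sketch checking the corrected `call` arithmetic against an explicit NumPy computation (the learned scale is taken as ones and the epsilon is chosen arbitrarily here):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0, 4.0]])
eps = 1e-6
weight = tf.ones(4)                                   # the layer's learned scale, ones here for illustration

# What TFT5LayerNorm.call now computes: x / sqrt(mean(x**2) + eps), then scaled by the weight.
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
y = weight * (x * tf.math.rsqrt(variance + eps))

expected = x.numpy() / np.sqrt(np.mean(np.square(x.numpy()), axis=-1, keepdims=True) + eps)
np.testing.assert_allclose(y.numpy(), expected, rtol=1e-5)
```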
@@ -231,19 +231,19 @@ class TFT5Attention(tf.keras.layers.Layer):
                 cache[self.layer_id] = (k, v)
         # q = q / math.sqrt(dim_per_head)  # No scaling in T5
-        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
+        # scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
+        scores = tf.einsum('bnqd,bnkd->bnqk', q, k)  # (bs, n_heads, qlen, klen)
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
-        scores += position_bias
-        if mask is not None:
-            scores += mask
-            # mask = (mask == 0).expand_as(scores)  # (bs, n_heads, qlen, klen)
-            # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+            if mask is not None:
+                position_bias = position_bias + mask
+                # mask = (mask == 0).expand_as(scores)  # (bs, n_heads, qlen, klen)
+                # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+        scores += position_bias
         weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
         weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
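Note on the attention hunk above: the score computation switches from `tf.matmul(q, k, transpose_b=True)` to an equivalent `tf.einsum`, and the padding mask is now folded into `position_bias` when the bias is first computed instead of being added to `scores` separately; since both terms are simply added to the raw scores, the result is unchanged. A sketch checking both equivalences on random tensors (shapes and values are illustrative):

```python
import tensorflow as tf

bs, n_heads, qlen, klen, d = 2, 3, 4, 4, 8
q = tf.random.normal((bs, n_heads, qlen, d))
k = tf.random.normal((bs, n_heads, klen, d))
position_bias = tf.random.normal((1, n_heads, qlen, klen))
mask = tf.constant([[[[0.0, 0.0, -1e9, -1e9]]]], dtype=tf.float32)   # additive mask, broadcastable

# 1) einsum produces the same (bs, n_heads, qlen, klen) scores as the transposed matmul.
scores_matmul = tf.matmul(q, k, transpose_b=True)
scores_einsum = tf.einsum('bnqd,bnkd->bnqk', q, k)
tf.debugging.assert_near(scores_matmul, scores_einsum, atol=1e-4)

# 2) adding the mask to the bias first is equivalent to adding it to the scores later.
old = scores_einsum + position_bias + mask        # old ordering
new = scores_einsum + (position_bias + mask)      # new ordering: mask folded into the bias
tf.debugging.assert_near(old, new, atol=1e-4)
```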
@@ -350,11 +350,11 @@ class TFT5Block(tf.keras.layers.Layer):
                 head_mask=head_mask,
                 training=training)
             hidden_states = cross_attention_outputs[0]
-            outputs = cross_attention_outputs[1:] + outputs
+            outputs = outputs + cross_attention_outputs[1:]
         hidden_states = self.layer[2](hidden_states, training=training)
         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs
+        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
 ####################################################
@@ -418,7 +418,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+        # extended_attention_mask = tf.math.equal(extended_attention_mask,
+        #                                         tf.transpose(extended_attention_mask, perm=(-1, -2)))
+        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
         if self.is_decoder:
             # If a 2D ou 3D attention mask is provided for the cross-attention
@@ -430,7 +436,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             if num_dims_encoder_attention_mask == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         else:
             encoder_extended_attention_mask = None
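Note on the TFT5MainLayer hunks above: as in the PyTorch file, the 2D masks are broadcast to (batch, 1, 1, seq_len) and turned into additive biases, and the fill value changes from -10000.0 to -1e9 to match the PyTorch implementation. A short illustrative sketch of the broadcasting step in TensorFlow:

```python
import tensorflow as tf

encoder_attention_mask = tf.constant([[1, 1, 0], [1, 0, 0]], dtype=tf.float32)     # (bs=2, seq_len=3)
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]          # (2, 1, 1, 3)
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9    # keep -> 0.0, masked -> -1e9

print(encoder_extended_attention_mask.shape)       # (2, 1, 1, 3)
print(encoder_extended_attention_mask[0, 0, 0])    # [0., 0., -1e9]
```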
@@ -463,6 +474,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                 training=training)
             hidden_states = layer_outputs[0]
             if i == 0:
+                # We share the position biases between the layers - the first layer store them
+                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
@@ -502,8 +515,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
     @property
     def dummy_inputs(self):
-        input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-        input_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        input_ids = tf.constant(DUMMY_INPUTS)
+        input_mask = tf.constant(DUMMY_MASK)
         dummy_inputs = {'decoder_input_ids': input_ids,
                         'encoder_input_ids': input_ids,
                         'decoder_attention_mask': input_mask}
modeling_tf_utils.py
@@ -24,13 +24,11 @@ import os
 import tensorflow as tf
 from .configuration_utils import PretrainedConfig
-from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS
 from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 logger = logging.getLogger(__name__)
-DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
 class TFPreTrainedModel(tf.keras.Model):
     r""" Base class for all TF models.
@@ -59,7 +57,7 @@ class TFPreTrainedModel(tf.keras.Model):
         Returns:
             tf.Tensor with dummy inputs
         """
-        return tf.constant(DUMMY_INPUTS)
+        return {'input_ids': tf.constant(DUMMY_INPUTS)}
     def __init__(self, config, *inputs, **kwargs):
         super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
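Note on the modeling_tf_utils.py hunks above: `dummy_inputs` now returns a dict keyed by input name rather than a bare tensor, so subclasses that need several named inputs (such as the TFT5PreTrainedModel override earlier in this diff) can extend it without changing call sites like `tf_model(tf_model.dummy_inputs, training=False)` in the conversion script. A hedged sketch of the pattern with a toy `tf.keras.Model` subclass standing in for `TFPreTrainedModel` (the class and its behaviour here are illustrative only):

```python
import tensorflow as tf

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

class ToyTFModel(tf.keras.Model):
    @property
    def dummy_inputs(self):
        # Same kind of return value as the new TFPreTrainedModel.dummy_inputs: a dict of tensors.
        return {'input_ids': tf.constant(DUMMY_INPUTS)}

    def call(self, inputs, training=False):
        # Accept either a dict of named inputs or a bare tensor.
        input_ids = inputs['input_ids'] if isinstance(inputs, dict) else inputs
        return tf.cast(input_ids, tf.float32)

model = ToyTFModel()
out = model(model.dummy_inputs, training=False)   # builds the network from the dummy batch
print(out.shape)                                  # (3, 5)
```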
modeling_utils.py
@@ -31,11 +31,10 @@ from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from .configuration_utils import PretrainedConfig
-from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS
 logger = logging.getLogger(__name__)
 try:
     from torch.nn import Identity
 except ImportError:
@@ -71,6 +70,15 @@ class PreTrainedModel(nn.Module):
     load_tf_weights = lambda model, config, path: None
     base_model_prefix = ""
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to do a forward pass in the network.
+        Returns:
+            torch.Tensor with dummy inputs
+        """
+        return {'input_ids': torch.tensor(DUMMY_INPUTS)}
     def __init__(self, config, *inputs, **kwargs):
         super(PreTrainedModel, self).__init__()
         if not isinstance(config, PretrainedConfig):
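Note on the modeling_utils.py hunks above: the same property is added on the PyTorch side, and the conversion script consumes it as `pt_model(**pt_model.dummy_inputs)`. A minimal illustrative sketch of the override pattern used by `T5PreTrainedModel` earlier in this diff (the toy modules below are not real transformers classes):

```python
import torch
import torch.nn as nn

DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

class ToyModel(nn.Module):
    @property
    def dummy_inputs(self):
        # Base-class behaviour: a single batch of input ids.
        return {'input_ids': torch.tensor(DUMMY_INPUTS)}

    def forward(self, input_ids=None):
        return input_ids.float()

class ToySeq2SeqModel(ToyModel):
    @property
    def dummy_inputs(self):
        # Seq2seq override, mirroring T5PreTrainedModel.dummy_inputs above.
        input_ids = torch.tensor(DUMMY_INPUTS)
        return {'decoder_input_ids': input_ids,
                'encoder_input_ids': input_ids,
                'decoder_attention_mask': torch.tensor(DUMMY_MASK)}

    def forward(self, encoder_input_ids=None, decoder_input_ids=None, decoder_attention_mask=None):
        return encoder_input_ids.float() * decoder_attention_mask.float()

for model in (ToyModel(), ToySeq2SeqModel()):
    with torch.no_grad():
        out = model(**model.dummy_inputs)   # same call works for both thanks to the dict-valued property
    print(out.shape)                        # torch.Size([3, 5]) in both cases
```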