chenpangpang / transformers · Commits

Commit fafd4c86 authored Dec 11, 2019 by thomwolf

fix TF 2.0 version of T5 - update conversion script

parent 67a8be8e

Showing 6 changed files with 65 additions and 31 deletions:

    transformers/convert_pytorch_checkpoint_to_tf2.py   +4   -7
    transformers/file_utils.py                           +3   -0
    transformers/modeling_t5.py                          +18  -3
    transformers/modeling_tf_t5.py                       +28  -15
    transformers/modeling_tf_utils.py                    +2   -4
    transformers/modeling_utils.py                       +10  -2
transformers/convert_pytorch_checkpoint_to_tf2.py

@@ -120,24 +120,21 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
 
     if compare_with_pt_model:
-        inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-        tf_inputs = tf_model.dummy_inputs
-        tfo = tf_model(tf_inputs, training=False)  # build the network
+        tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network
 
         state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
         pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
                                                   config=config,
                                                   state_dict=state_dict)
-        pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
-            pto = pt_model(pt_inputs)
+            pto = pt_model(**pt_model.dummy_inputs)
 
-        np_pt = pto[0].detach().numpy()
+        np_pt = pto[0].numpy()
         np_tf = tfo[0].numpy()
         diff = np.amax(np.abs(np_pt - np_tf))
         print("Max absolute difference between models outputs {}".format(diff))
-        assert diff <= 2e-2, "Error, model absolute difference is >2e-2"
+        assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)
 
     # Save pytorch-model
     print("Save TensorFlow model to {}".format(tf_dump_path))
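The block above replaces hard-coded input lists with each model's dummy_inputs and tightens the assertion message; the actual comparison is just a max-absolute-difference check between the two frameworks' first outputs. Below is a minimal standalone sketch of that check, with small placeholder NumPy arrays standing in for pto[0] and tfo[0] (the 2e-2 tolerance comes from the diff; the array values are made up):

    # Placeholder arrays stand in for the PyTorch output pto[0] and TF output tfo[0].
    import numpy as np

    np_pt = np.array([[0.101, -0.499], [0.250, 0.750]])   # pretend PyTorch output
    np_tf = np.array([[0.100, -0.500], [0.251, 0.749]])   # pretend TensorFlow output

    diff = np.amax(np.abs(np_pt - np_tf))
    print("Max absolute difference between models outputs {}".format(diff))
    assert diff <= 2e-2, "Error, model absolute difference is >2e-2: {}".format(diff)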
transformers/file_utils.py

@@ -73,6 +73,9 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
 TF_WEIGHTS_NAME = 'model.ckpt'
 CONFIG_NAME = "config.json"
 
+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
+
 
 def is_torch_available():
     return _torch_available
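The two constants added above are a shared 3x5 batch of dummy token ids and a matching attention mask (1 = attend, 0 = masked) that the model files below import instead of redefining locally. A quick shape sanity check, plain Python only:

    # Standalone sanity check of the shared dummy constants (values copied from the hunk above).
    DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
    DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

    assert len(DUMMY_INPUTS) == len(DUMMY_MASK) == 3                 # batch of 3 sequences
    assert all(len(row) == 5 for row in DUMMY_INPUTS + DUMMY_MASK)   # 5 tokens per sequence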
transformers/modeling_t5.py

@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .modeling_utils import PreTrainedModel
 from .configuration_t5 import T5Config
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
 
 logger = logging.getLogger(__name__)

@@ -451,6 +451,15 @@ class T5PreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"
 
+    @property
+    def dummy_inputs(self):
+        input_ids = torch.tensor(DUMMY_INPUTS)
+        input_mask = torch.tensor(DUMMY_MASK)
+        dummy_inputs = {'decoder_input_ids': input_ids,
+                        'encoder_input_ids': input_ids,
+                        'decoder_attention_mask': input_mask}
+        return dummy_inputs
+
     def _init_weights(self, module):
         """ Initialize the weights """
         factor = self.config.initializer_factor  # Used for testing weights initialization

@@ -534,9 +543,10 @@ class T5Stack(T5PreTrainedModel):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
 
-        # T5 has a mask that can compare sequence ids, we simulate this here with this transposistion
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
         # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
-        extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
+        # extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
+
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
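The (1.0 - mask) * -1e9 line in the hunk above converts a 0/1 padding mask into an additive bias that is added to the raw attention scores before the softmax, so masked positions end up with effectively zero weight. A standalone PyTorch sketch with a toy mask (shapes and values are illustrative, not taken from the model):

    # Toy illustration of the additive attention mask: 1 = attend (bias 0.0),
    # 0 = masked (bias -1e9, i.e. ~zero probability after the softmax).
    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])                  # (batch, seq_len)
    extended_attention_mask = attention_mask[:, None, None, :]        # (batch, 1, 1, seq_len)
    extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)  # fp16/fp32 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

    print(extended_attention_mask)  # zeros where attending, -1e9 where masked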
@@ -548,6 +558,10 @@ class T5Stack(T5PreTrainedModel):
             if encoder_attention_mask.dim() == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
 
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))
+
             encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         else:

@@ -590,6 +604,7 @@ class T5Stack(T5PreTrainedModel):
             hidden_states = layer_outputs[0]
             if i == 0:
                 # We share the position biases between the layers - the first layer store them
+                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
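The comment added in the last hunk documents the tuple layout that the [2 if self.output_attentions else 1] indexing relies on: the shared position bias sits directly after the optional attention weights. A small placeholder sketch of that indexing (the tuple elements are strings standing in for tensors):

    # Placeholder illustration of picking the shared position bias out of
    # layer_outputs depending on whether attention weights are returned.
    hidden_states = "hidden-states"
    attn_weights = "self-attention weights"
    position_bias_value = "self-attention position bias"

    for output_attentions in (True, False):
        if output_attentions:
            layer_outputs = (hidden_states, attn_weights, position_bias_value)
        else:
            layer_outputs = (hidden_states, position_bias_value)
        position_bias = layer_outputs[2 if output_attentions else 1]
        assert position_bias == position_bias_value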
transformers/modeling_tf_t5.py

@@ -26,7 +26,7 @@ import tensorflow as tf
 from .configuration_t5 import T5Config
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
-from .file_utils import add_start_docstrings
+from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
 
 logger = logging.getLogger(__name__)

@@ -61,7 +61,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
         super(TFT5LayerNorm, self).build(input_shape)
 
     def call(self, x):
-        variance = tf.math.reduce_min(tf.math.square(x), axis=-1, keepdims=True)
+        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
         x = x * tf.math.rsqrt(variance + self.variance_epsilon)
         return self.weight * x
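The layer-norm hunk swaps reduce_min for reduce_mean: T5's layer norm scales activations by their root mean square (no mean subtraction), so the variance term must be a mean of squares, not a minimum. A standalone sketch of the corrected computation with a toy tensor, an assumed epsilon of 1e-6, and a ones vector standing in for the learned weight:

    # Sketch of the corrected T5 layer norm (RMS-style, no mean subtraction).
    # The epsilon and the ones weight are assumptions for illustration only.
    import tensorflow as tf

    x = tf.random.normal((2, 4, 8))     # (batch, seq_len, hidden)
    variance_epsilon = 1e-6             # assumed value
    weight = tf.ones(8)                 # stand-in for the learned scale parameter

    variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
    x = x * tf.math.rsqrt(variance + variance_epsilon)
    out = weight * x
    print(out.shape)                    # (2, 4, 8)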
@@ -231,19 +231,19 @@ class TFT5Attention(tf.keras.layers.Layer):
                 cache[self.layer_id] = (k, v)
 
         # q = q / math.sqrt(dim_per_head) # No scaling in T5
-        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
+        # scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)
+        scores = tf.einsum('bnqd,bnkd->bnqk', q, k)  # (bs, n_heads, qlen, klen)
 
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
-        scores += position_bias
-        if mask is not None:
-            scores += mask
-            # mask = (mask == 0).expand_as(scores)  # (bs, n_heads, qlen, klen)
-            # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+            if mask is not None:
+                position_bias = position_bias + mask
+                # mask = (mask == 0).expand_as(scores)  # (bs, n_heads, qlen, klen)
+                # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
 
+        scores += position_bias
         weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
         weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
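The einsum introduced above produces the same (bs, n_heads, qlen, klen) score tensor as the matmul with transpose_b=True that it comments out; the behavioral change in this hunk is that the mask is now folded into position_bias before both are added to the scores. A quick check that the two score formulations agree on random tensors (the shapes here are arbitrary):

    # Verify tf.einsum('bnqd,bnkd->bnqk', q, k) matches tf.matmul(q, k, transpose_b=True).
    # bs=2, n_heads=3, qlen=4, klen=5, dim_per_head=8 are arbitrary test shapes.
    import numpy as np
    import tensorflow as tf

    q = tf.random.normal((2, 3, 4, 8))
    k = tf.random.normal((2, 3, 5, 8))

    scores_einsum = tf.einsum('bnqd,bnkd->bnqk', q, k)
    scores_matmul = tf.matmul(q, k, transpose_b=True)

    assert np.allclose(scores_einsum.numpy(), scores_matmul.numpy(), atol=1e-4)
    print(scores_einsum.shape)  # (2, 3, 4, 5)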
@@ -350,11 +350,11 @@ class TFT5Block(tf.keras.layers.Layer):
                                                            head_mask=head_mask,
                                                            training=training)
             hidden_states = cross_attention_outputs[0]
-            outputs = cross_attention_outputs[1:] + outputs
+            outputs = outputs + cross_attention_outputs[1:]
 
         hidden_states = self.layer[2](hidden_states, training=training)
         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs
+        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
 
 
 ####################################################
@@ -418,7 +418,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+        # extended_attention_mask = tf.math.equal(extended_attention_mask,
+        #                                         tf.transpose(extended_attention_mask, perm=(-1, -2)))
+
+        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
 
         if self.is_decoder:
             # If a 2D ou 3D attention mask is provided for the cross-attention

@@ -430,7 +436,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             if num_dims_encoder_attention_mask == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposistion
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         else:
             encoder_extended_attention_mask = None
@@ -463,6 +474,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
                                        training=training)
             hidden_states = layer_outputs[0]
             if i == 0:
+                # We share the position biases between the layers - the first layer store them
+                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                 position_bias = layer_outputs[2 if self.output_attentions else 1]
                 if self.is_decoder:
                     encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
@@ -502,8 +515,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
     @property
     def dummy_inputs(self):
-        input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-        input_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        input_ids = tf.constant(DUMMY_INPUTS)
+        input_mask = tf.constant(DUMMY_MASK)
         dummy_inputs = {'decoder_input_ids': input_ids,
                         'encoder_input_ids': input_ids,
                         'decoder_attention_mask': input_mask}
transformers/modeling_tf_utils.py

@@ -24,13 +24,11 @@ import os
 import tensorflow as tf
 
 from .configuration_utils import PretrainedConfig
-from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS
 from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 
 logger = logging.getLogger(__name__)
 
-DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-
 
 class TFPreTrainedModel(tf.keras.Model):
     r""" Base class for all TF models.

@@ -59,7 +57,7 @@ class TFPreTrainedModel(tf.keras.Model):
         Returns:
             tf.Tensor with dummy inputs
         """
-        return tf.constant(DUMMY_INPUTS)
+        return {'input_ids': tf.constant(DUMMY_INPUTS)}
 
     def __init__(self, config, *inputs, **kwargs):
         super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
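Returning a dict keyed by 'input_ids' instead of a bare tensor means the first call that builds the Keras weights sees the same input structure as a real forward pass. A hypothetical toy subclassed model (not the actual TFPreTrainedModel) showing that pattern:

    # Hypothetical toy model: the weights are built by calling it once on a
    # {'input_ids': ...} dict, mirroring the new dummy_inputs contract.
    import tensorflow as tf

    DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

    class ToyModel(tf.keras.Model):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.embeddings = tf.keras.layers.Embedding(input_dim=32, output_dim=8)

        def call(self, inputs, training=False):
            return self.embeddings(inputs['input_ids'])

    model = ToyModel()
    _ = model({'input_ids': tf.constant(DUMMY_INPUTS)}, training=False)  # builds the weights
    print(model.count_params())  # 32 * 8 = 256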
transformers/modeling_utils.py

@@ -31,11 +31,10 @@ from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 
 from .configuration_utils import PretrainedConfig
-from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME, DUMMY_INPUTS
 
 logger = logging.getLogger(__name__)
 
 try:
     from torch.nn import Identity
 except ImportError:

@@ -71,6 +70,15 @@ class PreTrainedModel(nn.Module):
     load_tf_weights = lambda model, config, path: None
     base_model_prefix = ""
 
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to do a forward pass in the network.
+
+        Returns:
+            torch.Tensor with dummy inputs
+        """
+        return {'input_ids': torch.tensor(DUMMY_INPUTS)}
+
     def __init__(self, config, *inputs, **kwargs):
         super(PreTrainedModel, self).__init__()
 
         if not isinstance(config, PretrainedConfig):
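On the PyTorch side the same idea applies: dummy_inputs is now a dict keyed by forward-argument names, so it can be splatted into a forward pass as model(**model.dummy_inputs), which is exactly what the updated conversion script does. A hypothetical toy module (not the real PreTrainedModel) illustrating the contract:

    # Hypothetical toy module: dummy_inputs keys must match forward() argument
    # names so that **-expansion works.
    import torch
    import torch.nn as nn

    DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.embeddings = nn.Embedding(32, 8)

        @property
        def dummy_inputs(self):
            return {'input_ids': torch.tensor(DUMMY_INPUTS)}

        def forward(self, input_ids=None):
            return self.embeddings(input_ids)

    model = ToyModel()
    with torch.no_grad():
        out = model(**model.dummy_inputs)
    print(out.shape)  # torch.Size([3, 5, 8])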