chenpangpang / transformers

Commit 155c782a, authored Nov 11, 2019 by Julien Chaumond

[inputs_embeds] All TF models + tests

parent 2aef2f0b

Showing 11 changed files with 252 additions and 105 deletions (+252 / -105)
transformers/modeling_tf_bert.py                +41  -18
transformers/modeling_tf_ctrl.py                +18   -9
transformers/modeling_tf_distilbert.py          +27  -11
transformers/modeling_tf_gpt2.py                +30  -14
transformers/modeling_tf_openai.py              +30  -14
transformers/modeling_tf_roberta.py              +7   -3
transformers/modeling_tf_transfo_xl.py          +35  -18
transformers/modeling_tf_xlm.py                 +23   -8
transformers/modeling_tf_xlnet.py               +20   -6
transformers/tests/modeling_tf_bert_test.py      +0   -4
transformers/tests/modeling_tf_common_test.py   +21   -0
transformers/modeling_tf_bert.py

@@ -142,19 +142,25 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
     def _embedding(self, inputs, training=False):
         """Applies embedding based on inputs tensor."""
-        input_ids, position_ids, token_type_ids = inputs
+        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
+
+        if input_ids is not None:
+            input_shape = tf.shape(input_ids)
+        else:
+            input_shape = tf.shape(inputs_embeds)[:-1]
 
-        seq_length = tf.shape(input_ids)[1]
+        seq_length = input_shape[1]
         if position_ids is None:
             position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(input_shape, 0)
 
-        words_embeddings = tf.gather(self.word_embeddings, input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
-        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
         return embeddings

@@ -473,28 +479,39 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
         else:
             input_ids = inputs
 
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.shape
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.shape[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if attention_mask is None:
-            attention_mask = tf.fill(tf.shape(input_ids), 1)
+            attention_mask = tf.fill(input_shape, 1)
         if token_type_ids is None:
-            token_type_ids = tf.fill(tf.shape(input_ids), 0)
+            token_type_ids = tf.fill(input_shape, 0)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]

@@ -523,7 +540,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids], training=training)
+        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
         encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
 
         sequence_output = encoder_outputs[0]

@@ -901,33 +918,39 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
                                                 kernel_initializer=get_initializer(config.initializer_range),
                                                 name='classifier')
 
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
         else:
             input_ids = inputs
 
-        num_choices = tf.shape(input_ids)[1]
-        seq_length = tf.shape(input_ids)[2]
+        if input_ids is not None:
+            num_choices = tf.shape(input_ids)[1]
+            seq_length = tf.shape(input_ids)[2]
+        else:
+            num_choices = tf.shape(inputs_embeds)[1]
+            seq_length = tf.shape(inputs_embeds)[2]
 
-        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
 
-        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
+        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
 
         outputs = self.bert(flat_inputs, training=training)
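With this change every TF model can consume precomputed token embeddings in place of input_ids. A minimal usage sketch for BERT, assuming a small randomly initialized config and a dummy embedding tensor (real embeddings would normally come from the model's own word-embedding matrix):

import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = TFBertModel(config)

# Dummy precomputed embeddings with shape (batch_size, seq_length, hidden_size).
inputs_embeds = tf.ones((1, 4, config.hidden_size), dtype=tf.float32)

sequence_output, pooled_output = model({'inputs_embeds': inputs_embeds})[:2]

# Passing both ids and embeddings is now rejected:
# model({'input_ids': tf.constant([[1, 2, 3, 4]]), 'inputs_embeds': inputs_embeds})
# -> ValueError: You cannot specify both input_ids and inputs_embeds at the same time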
transformers/modeling_tf_ctrl.py

@@ -204,7 +204,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             past = inputs[1] if len(inputs) > 1 else past

@@ -212,7 +212,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
             position_ids = inputs[4] if len(inputs) > 4 else position_ids
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             past = inputs.get('past', past)

@@ -220,12 +221,20 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
 
-        input_shape = shape_list(input_ids)
-        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if past is None:
             past_length = 0

@@ -233,8 +242,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
         else:
             past_length = shape_list(past[0][0])[-2]
         if position_ids is None:
-            position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
-            position_ids = tf.tile(position_ids, [shape_list(input_ids)[0], 1])
+            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
+            position_ids = tf.tile(position_ids, [input_shape[0], 1])
 
         # Attention mask.
         if attention_mask is not None:

@@ -273,8 +282,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
             token_type_embeds = 0
 
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
-        inputs_embeds = self.w(input_ids, mode='embedding')
-        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
+        if inputs_embeds is None:
+            inputs_embeds = self.w(input_ids, mode='embedding')
         seq_len = input_shape[-1]
         mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
transformers/modeling_tf_distilbert.py

@@ -96,7 +96,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
                                                       initializer=get_initializer(self.initializer_range))
         super(TFEmbeddings, self).build(input_shape)
 
-    def call(self, inputs, mode="embedding", training=False):
+    def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
         """Get token embeddings of inputs.
         Args:
             inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)

@@ -112,13 +112,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
             https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
         if mode == "embedding":
-            return self._embedding(inputs, training=training)
+            return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
         elif mode == "linear":
             return self._linear(inputs)
         else:
             raise ValueError("mode {} is not valid.".format(mode))
 
-    def _embedding(self, inputs, training=False):
+    def _embedding(self, inputs, inputs_embeds=None, training=False):
         """
         Parameters
         ----------

@@ -136,14 +136,19 @@ class TFEmbeddings(tf.keras.layers.Layer):
         else:
             input_ids, position_ids = inputs
 
-        seq_length = tf.shape(input_ids)[1]
+        if input_ids is not None:
+            seq_length = tf.shape(input_ids)[1]
+        else:
+            seq_length = tf.shape(inputs_embeds)[1]
+
         if position_ids is None:
             position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
 
-        word_embeddings = tf.gather(self.word_embeddings, input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
 
-        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = inputs_embeds + position_embeddings  # (bs, max_seq_length, dim)
         embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
         embeddings = self.dropout(embeddings, training=training)  # (bs, max_seq_length, dim)
         return embeddings

@@ -407,22 +412,33 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
 
-    def call(self, inputs, attention_mask=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
-            assert len(inputs) <= 3, "Too many inputs."
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            assert len(inputs) <= 4, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 3, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 4, "Too many inputs."
         else:
             input_ids = inputs
 
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if attention_mask is None:
-            attention_mask = tf.ones(shape_list(input_ids))  # (bs, seq_length)
+            attention_mask = tf.ones(input_shape)  # (bs, seq_length)
         attention_mask = tf.cast(attention_mask, dtype=tf.float32)
 
         # Prepare head mask if needed

@@ -435,7 +451,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         else:
             head_mask = [None] * self.num_hidden_layers
 
-        embedding_output = self.embeddings(input_ids)  # (bs, seq_length, dim)
+        embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds)  # (bs, seq_length, dim)
         tfmr_output = self.transformer([embedding_output, attention_mask, head_mask], training=training)
 
         return tfmr_output  # last-layer hidden-state, (all hidden_states), (all attentions)
transformers/modeling_tf_gpt2.py

@@ -231,7 +231,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             past = inputs[1] if len(inputs) > 1 else past

@@ -239,7 +239,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
             position_ids = inputs[4] if len(inputs) > 4 else position_ids
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             past = inputs.get('past', past)

@@ -247,17 +248,28 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
 
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
         else:
             past_length = shape_list(past[0][0])[-2]
         if position_ids is None:
-            position_ids = tf.range(past_length, shape_list(input_ids)[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
+            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
 
         if attention_mask is not None:
             # We create a 3D attention mask from a 2D tensor mask.

@@ -289,10 +301,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        input_shape = shape_list(input_ids)
-        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
 
-        inputs_embeds = self.wte(input_ids, mode='embedding')
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids, mode='embedding')
         position_embeds = self.wpe(position_ids)
         if token_type_ids is not None:

@@ -569,7 +580,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.transformer.wte
 
-    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
+    def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             past = inputs[1] if len(inputs) > 1 else past

@@ -577,8 +588,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
             position_ids = inputs[4] if len(inputs) > 4 else position_ids
             head_mask = inputs[5] if len(inputs) > 5 else head_mask
-            mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
-            assert len(inputs) <= 7, "Too many inputs."
+            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
+            mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
+            assert len(inputs) <= 8, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             past = inputs.get('past', past)

@@ -586,21 +598,25 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
             mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
-            assert len(inputs) <= 7, "Too many inputs."
+            assert len(inputs) <= 8, "Too many inputs."
         else:
             input_ids = inputs
 
-        input_shapes = shape_list(input_ids)
+        if input_ids is not None:
+            input_shapes = shape_list(input_ids)
+        else:
+            input_shapes = shape_list(inputs_embeds)[:-1]
 
         seq_length = input_shapes[-1]
 
-        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
 
-        flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
+        flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
 
         transformer_outputs = self.transformer(flat_inputs, training=training)
         hidden_states = transformer_outputs[0]
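For GPT-2 the fallback path is wte(input_ids, mode='embedding'), so feeding the same embeddings back in should reproduce the ids-based output. A sketch, assuming a tiny randomly initialized TFGPT2Model and the dict-style call these models accept:

import tensorflow as tf
from transformers import GPT2Config, TFGPT2Model

config = GPT2Config(vocab_size=100, n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=2)
model = TFGPT2Model(config)

input_ids = tf.constant([[10, 20, 30]])
_ = model(input_ids)  # run once so the shared embedding weights are built
inputs_embeds = model.transformer.wte(input_ids, mode='embedding')

out_from_ids = model(input_ids)[0]
out_from_embeds = model({'inputs_embeds': inputs_embeds})[0]
# The two hidden-state tensors should match, since the main layer only calls
# wte(input_ids, mode='embedding') when inputs_embeds is None.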
transformers/modeling_tf_openai.py

@@ -229,26 +229,38 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 5, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
         else:
             input_ids = inputs
 
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if position_ids is None:
-            position_ids = tf.range(shape_list(input_ids)[-1], dtype=tf.int32)[tf.newaxis, :]
+            position_ids = tf.range(input_shape[-1], dtype=tf.int32)[tf.newaxis, :]
 
         if attention_mask is not None:
             # We create a 3D attention mask from a 2D tensor mask.

@@ -280,10 +292,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        input_shape = shape_list(input_ids)
-        input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
         position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
 
-        inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
+        if inputs_embeds is None:
+            inputs_embeds = self.tokens_embed(input_ids, mode='embedding')
         position_embeds = self.positions_embed(position_ids)
         if token_type_ids is not None:

@@ -533,36 +544,41 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
     def get_output_embeddings(self):
         return self.transformer.tokens_embed
 
-    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
+    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
             token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
             position_ids = inputs[3] if len(inputs) > 3 else position_ids
             head_mask = inputs[4] if len(inputs) > 4 else head_mask
-            mc_token_ids = inputs[5] if len(inputs) > 5 else mc_token_ids
-            assert len(inputs) <= 6, "Too many inputs."
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            mc_token_ids = inputs[6] if len(inputs) > 6 else mc_token_ids
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             position_ids = inputs.get('position_ids', position_ids)
             head_mask = inputs.get('head_mask', head_mask)
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
             mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
 
-        input_shapes = shape_list(input_ids)
+        if input_ids is not None:
+            input_shapes = shape_list(input_ids)
+        else:
+            input_shapes = shape_list(inputs_embeds)[:-1]
 
         seq_length = input_shapes[-1]
 
-        flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
 
-        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
+        flat_inputs = [flat_input_ids, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask, inputs_embeds]
 
         transformer_outputs = self.transformer(flat_inputs, training=training)
         hidden_states = transformer_outputs[0]
transformers/modeling_tf_roberta.py

@@ -48,13 +48,17 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
     def _embedding(self, inputs, training=False):
         """Applies embedding based on inputs tensor."""
-        input_ids, position_ids, token_type_ids = inputs
+        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
+
+        if input_ids is not None:
+            seq_length = tf.shape(input_ids)[1]
+        else:
+            seq_length = tf.shape(inputs_embeds)[1]
 
-        seq_length = tf.shape(input_ids)[1]
         if position_ids is None:
             position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
 
-        return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids], training=training)
+        return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
 
 
 class TFRobertaMainLayer(TFBertMainLayer):
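RoBERTa keeps its padding-offset position ids; when only embeddings are passed, the sequence length is now read from inputs_embeds instead of input_ids. A standalone TensorFlow mirror of that position-id construction, assuming padding_idx is 1 as in the layer:

import tensorflow as tf

padding_idx = 1
inputs_embeds = tf.random.uniform((2, 5, 8))  # (batch_size, seq_length, hidden_size)
seq_length = tf.shape(inputs_embeds)[1]

# Positions start right after the padding index, as in the layer above.
position_ids = tf.range(padding_idx + 1, seq_length + padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
# -> [[2, 3, 4, 5, 6]]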
transformers/modeling_tf_transfo_xl.py

@@ -430,11 +430,11 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads):
         raise NotImplementedError
 
-    def init_mems(self, data):
+    def init_mems(self, bsz):
         if self.mem_len > 0:
             mems = []
             for i in range(self.n_layer):
-                empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model])
+                empty = tf.zeros([self.mem_len, bsz, self.d_model])
                 mems.append(empty)
 
             return mems

@@ -464,28 +464,37 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         return new_mems
 
-    def call(self, inputs, mems=None, head_mask=None, training=False):
+    def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             mems = inputs[1] if len(inputs) > 1 else mems
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
-            assert len(inputs) <= 3, "Too many inputs."
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            assert len(inputs) <= 4, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             mems = inputs.get('mems', mems)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 3, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 4, "Too many inputs."
         else:
             input_ids = inputs
 
         # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
         # so we transpose here from shape [bsz, len] to shape [len, bsz]
-        input_ids = tf.transpose(input_ids, perm=(1, 0))
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_ids = tf.transpose(input_ids, perm=(1, 0))
+            qlen, bsz = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
+            qlen, bsz = shape_list(inputs_embeds)[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if mems is None:
-            mems = self.init_mems(input_ids)
+            mems = self.init_mems(bsz)
 
-        qlen, bsz = shape_list(input_ids)
-
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head

@@ -497,6 +506,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         else:
             head_mask = [None] * self.n_layer
 
-        word_emb = self.word_emb(input_ids)
+        if inputs_embeds is not None:
+            word_emb = inputs_embeds
+        else:
+            word_emb = self.word_emb(input_ids)
 
         mlen = shape_list(mems[0])[0] if mems is not None else 0

@@ -723,28 +735,33 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
     def reset_length(self, tgt_len, ext_len, mem_len):
         self.transformer.reset_length(tgt_len, ext_len, mem_len)
 
-    def init_mems(self, data):
-        return self.transformer.init_mems(data)
+    def init_mems(self, bsz):
+        return self.transformer.init_mems(bsz)
 
-    def call(self, inputs, mems=None, head_mask=None, labels=None, training=False):
+    def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, labels=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             mems = inputs[1] if len(inputs) > 1 else mems
             head_mask = inputs[2] if len(inputs) > 2 else head_mask
-            labels = inputs[3] if len(inputs) > 3 else labels
-            assert len(inputs) <= 4, "Too many inputs."
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            labels = inputs[4] if len(inputs) > 4 else labels
+            assert len(inputs) <= 5, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             mems = inputs.get('mems', mems)
             head_mask = inputs.get('head_mask', head_mask)
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
             labels = inputs.get('labels', labels)
-            assert len(inputs) <= 4, "Too many inputs."
+            assert len(inputs) <= 5, "Too many inputs."
         else:
             input_ids = inputs
 
-        bsz, tgt_len = shape_list(input_ids)[:2]
+        if input_ids is not None:
+            bsz, tgt_len = shape_list(input_ids)[:2]
+        else:
+            bsz, tgt_len = shape_list(inputs_embeds)[:2]
 
-        transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training)
+        transformer_outputs = self.transformer([input_ids, mems, head_mask, inputs_embeds], training=training)
 
         last_hidden = transformer_outputs[0]
         pred_hid = last_hidden[:, -tgt_len:]
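init_mems now takes a batch size instead of the input tensor, so the memory can be allocated even when only embeddings are provided. A standalone mirror of the revised helper with made-up sizes:

import tensorflow as tf

def init_mems(bsz, mem_len=4, n_layer=2, d_model=8):
    # Memory shape now depends only on the batch size, not on an input_ids tensor
    # (mem_len, n_layer and d_model here are made-up stand-ins for the config values).
    if mem_len > 0:
        return [tf.zeros([mem_len, bsz, d_model]) for _ in range(n_layer)]
    return None

mems = init_mems(bsz=3)  # two tensors, each of shape (mem_len, bsz, d_model) = (4, 3, 8)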
transformers/modeling_tf_xlm.py

@@ -291,7 +291,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         raise NotImplementedError
 
     def call(self, inputs, attention_mask=None, langs=None, token_type_ids=None,
-             position_ids=None, lengths=None, cache=None, head_mask=None,
+             position_ids=None, lengths=None, cache=None, head_mask=None, inputs_embeds=None,
              training=False):  # removed: src_enc=None, src_len=None
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]

@@ -302,7 +302,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             lengths = inputs[5] if len(inputs) > 5 else lengths
             cache = inputs[6] if len(inputs) > 6 else cache
             head_mask = inputs[7] if len(inputs) > 7 else head_mask
-            assert len(inputs) <= 8, "Too many inputs."
+            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)

@@ -312,16 +313,28 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             lengths = inputs.get('lengths', lengths)
             cache = inputs.get('cache', cache)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 8, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs
 
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            bs, slen = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            bs, slen = shape_list(inputs_embeds)[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if lengths is None:
-            lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
+            if input_ids is not None:
+                lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
+            else:
+                lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
         # mask = input_ids != self.pad_index
 
         # check inputs
-        bs, slen = shape_list(input_ids)
         # assert shape_list(lengths)[0] == bs
         tf.debugging.assert_equal(shape_list(lengths)[0], bs)
         # assert lengths.max().item() <= slen

@@ -361,7 +374,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             head_mask = [None] * self.n_layers
 
         # do not recompute cached elements
-        if cache is not None:
+        if cache is not None and input_ids is not None:
             _slen = slen - cache['slen']
             input_ids = input_ids[:, -_slen:]
             position_ids = position_ids[:, -_slen:]

@@ -371,8 +384,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             attn_mask = attn_mask[:, -_slen:]
 
         # embeddings
-        tensor = self.embeddings(input_ids)
-        tensor = tensor + self.position_embeddings(position_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        tensor = inputs_embeds + self.position_embeddings(position_ids)
         if langs is not None and self.use_lang_emb:
             tensor = tensor + self.lang_embeddings(langs)
         if token_type_ids is not None:
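When XLM receives only embeddings there are no token ids to count padding from, so lengths now defaults to the full sequence length for every row. A TensorFlow-only mirror of the two branches, with a made-up pad index and batch:

import tensorflow as tf

pad_index = 2
input_ids = tf.constant([[5, 6, 2, 2]])  # one row: two real tokens, two pads

# With ids: count the non-pad tokens per row, as before.
lengths_from_ids = tf.reduce_sum(
    tf.cast(tf.not_equal(input_ids, pad_index), dtype=tf.int32), axis=1)  # -> [2]

# With embeddings only: assume every row spans the full sequence length.
bs, slen = 1, 4
lengths_from_embeds = tf.convert_to_tensor([slen] * bs, tf.int32)  # -> [4]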
transformers/modeling_tf_xlnet.py

@@ -487,7 +487,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         return pos_emb
 
     def call(self, inputs, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
-             token_type_ids=None, input_mask=None, head_mask=None, training=False):
+             token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask

@@ -497,7 +497,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
             input_mask = inputs[6] if len(inputs) > 6 else input_mask
             head_mask = inputs[7] if len(inputs) > 7 else head_mask
-            assert len(inputs) <= 8, "Too many inputs."
+            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
+            assert len(inputs) <= 9, "Too many inputs."
         elif isinstance(inputs, dict):
             input_ids = inputs.get('input_ids')
             attention_mask = inputs.get('attention_mask', attention_mask)

@@ -507,7 +508,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             token_type_ids = inputs.get('token_type_ids', token_type_ids)
             input_mask = inputs.get('input_mask', input_mask)
             head_mask = inputs.get('head_mask', head_mask)
-            assert len(inputs) <= 8, "Too many inputs."
+            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
+            assert len(inputs) <= 9, "Too many inputs."
         else:
             input_ids = inputs

@@ -515,14 +517,23 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         # but we want a unified interface in the library with the batch size on the first dimension
         # so we move here the first dimension (batch) to the end
-        input_ids = tf.transpose(input_ids, perm=(1, 0))
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_ids = tf.transpose(input_ids, perm=(1, 0))
+            qlen, bsz = shape_list(input_ids)[:2]
+        elif inputs_embeds is not None:
+            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
+            qlen, bsz = shape_list(inputs_embeds)[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
         input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
         attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
         perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
         target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
 
-        qlen, bsz = shape_list(input_ids)[:2]
         mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
         klen = mlen + qlen

@@ -573,6 +584,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             non_tgt_mask = None
 
         ##### Word embeddings and prepare h & g hidden states
-        word_emb_k = self.word_embedding(input_ids)
+        if inputs_embeds is not None:
+            word_emb_k = inputs_embeds
+        else:
+            word_emb_k = self.word_embedding(input_ids)
         output_h = self.dropout(word_emb_k, training=training)
         if target_mapping is not None:
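XLNet works internally with a (seq_length, batch_size, hidden) layout, so precomputed embeddings are transposed with a three-axis permutation before qlen and bsz are read off. A TensorFlow-only mirror:

import tensorflow as tf

inputs_embeds = tf.random.uniform((2, 5, 8))                 # (batch_size, seq_length, hidden)
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))  # (seq_length, batch_size, hidden)
qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]   # 5, 2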
transformers/tests/modeling_tf_bert_test.py

@@ -131,10 +131,6 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
         def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = TFBertModel(config=config)
-            # inputs = {'input_ids': input_ids,
-            #           'attention_mask': input_mask,
-            #           'token_type_ids': token_type_ids}
-            # sequence_output, pooled_output = model(**inputs)
             inputs = {'input_ids': input_ids,
                       'attention_mask': input_mask,
                       'token_type_ids': token_type_ids}
transformers/tests/modeling_tf_common_test.py

@@ -411,6 +411,27 @@ class TFCommonTestCases:
             first, second = model(inputs_dict, training=False)[0], model(inputs_dict, training=False)[0]
             self.assertTrue(tf.math.equal(first, second).numpy().all())
 
+        def test_inputs_embeds(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            input_ids = inputs_dict["input_ids"]
+            del inputs_dict["input_ids"]
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+
+                wte = model.get_input_embeddings()
+                try:
+                    x = wte(input_ids, mode="embedding")
+                except:
+                    try:
+                        x = wte([input_ids], mode="embedding")
+                    except:
+                        x = tf.ones(input_ids.shape + [self.model_tester.hidden_size], dtype=tf.dtypes.float32)
+                # ^^ In our TF models, the input_embeddings can take slightly different forms,
+                # so we try two of them and fall back to just synthetically creating a dummy tensor of ones.
+
+                inputs_dict["inputs_embeds"] = x
+                outputs = model(inputs_dict)
+
 
 def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
     """Creates a random int32 tensor of the shape within the vocab size."""
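The new common test derives inputs_embeds from each model's own input-embedding layer and falls back to a dummy tensor when that layer's call signature differs. A standalone version of the same pattern for a single model, assuming a tiny randomly initialized TFBertModel:

import tensorflow as tf
from transformers import BertConfig, TFBertModel

config = BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=37)
model = TFBertModel(config)
input_ids = tf.constant([[3, 8, 21, 5]])

wte = model.get_input_embeddings()
try:
    inputs_embeds = wte(input_ids, mode="embedding")
except Exception:
    # Fall back to a dummy tensor of ones, exactly as the test above does.
    inputs_embeds = tf.ones(input_ids.shape + [config.hidden_size], dtype=tf.float32)

outputs = model({"inputs_embeds": inputs_embeds})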