Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c19b8e4a
Commit
c19b8e4a
authored
Oct 09, 2019
by
thomwolf
Browse files
fixing CTRL tests and OpenAI GPT tests
parent
6dce6dda
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
25 deletions
+31
-25
transformers/modeling_ctrl.py
transformers/modeling_ctrl.py
+24
-20
transformers/modeling_openai.py
transformers/modeling_openai.py
+1
-1
transformers/modeling_tf_ctrl.py
transformers/modeling_tf_ctrl.py
+2
-1
transformers/tests/modeling_tf_common_test.py
transformers/tests/modeling_tf_common_test.py
+4
-3
No files found.
transformers/modeling_ctrl.py
View file @
c19b8e4a
...
...
@@ -303,11 +303,6 @@ class CTRLModel(CTRLPreTrainedModel):
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
):
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
position_ids
is
not
None
:
position_ids
=
position_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
past
is
None
:
past_length
=
0
past
=
[
None
]
*
len
(
self
.
h
)
...
...
@@ -349,42 +344,51 @@ class CTRLModel(CTRLPreTrainedModel):
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
x
=
self
.
w
(
input_ids
)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
token_type_embeds
=
self
.
w
(
token_type_ids
)
token_type_embeds
*=
np
.
sqrt
(
self
.
d_model_size
)
else
:
token_type_embeds
=
0
position_ids
=
position_ids
.
view
(
-
1
,
input_shape
[
-
1
])
inputs_embeds
=
self
.
w
(
input_ids
)
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_ids
.
shape
[
-
1
]
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
x
.
device
)
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
inputs_embeds
.
device
)
inputs_embeds
*=
np
.
sqrt
(
self
.
d_model_size
)
x
*=
np
.
sqrt
(
self
.
d_model_size
)
pos_embeds
=
self
.
pos_encoding
[
position_ids
,
:].
to
(
inputs_embeds
.
device
)
pos_x
=
self
.
pos_encoding
[
position_ids
,
:].
to
(
x
.
device
)
x
+=
pos_x
hidden_states
=
inputs_embeds
+
pos_embeds
+
token_type_embeds
x
=
self
.
dropout
(
x
)
hidden_states
=
self
.
dropout
(
hidden_states
)
output_shape
=
input_shape
+
(
x
.
size
(
-
1
),)
output_shape
=
input_shape
+
(
inputs_embeds
.
size
(
-
1
),)
presents
=
()
all_hidden_states
=
()
all_attentions
=
[]
for
i
,
(
h
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
past
)):
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
x
.
view
(
*
output_shape
),)
outputs
=
h
(
x
,
all_hidden_states
=
all_hidden_states
+
(
hidden_states
.
view
(
*
output_shape
),)
outputs
=
h
(
hidden_states
,
mask
,
layer_past
=
layer_past
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
[
i
])
x
,
present
=
outputs
[:
2
]
hidden_states
,
present
=
outputs
[:
2
]
presents
=
presents
+
(
present
,)
if
self
.
output_attentions
:
all_attentions
.
append
(
outputs
[
2
])
x
=
self
.
layernorm
(
x
)
x
=
x
.
view
(
*
output_shape
)
hidden_states
=
self
.
layernorm
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
*
output_shape
)
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
x
,)
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
outputs
=
(
x
,
presents
)
outputs
=
(
hidden_states
,
presents
)
if
self
.
output_hidden_states
:
outputs
=
outputs
+
(
all_hidden_states
,)
if
self
.
output_attentions
:
...
...
transformers/modeling_openai.py
View file @
c19b8e4a
...
...
@@ -170,7 +170,7 @@ class Attention(nn.Module):
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights
# XD: self.b may be larger than w, so we need to crop it
b
=
self
.
bias
[:,
:,
:
w
.
size
(
-
2
),
:
w
.
size
(
-
1
)]
w
=
w
*
b
+
-
1e9
*
(
1
-
b
)
w
=
w
*
b
+
-
1e4
*
(
1
-
b
)
if
attention_mask
is
not
None
:
# Apply the attention mask
...
...
transformers/modeling_tf_ctrl.py
View file @
c19b8e4a
...
...
@@ -238,6 +238,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
past_length
=
shape_list
(
past
[
0
][
0
])[
-
2
]
if
position_ids
is
None
:
position_ids
=
tf
.
range
(
past_length
,
shape_list
(
input_ids
)[
-
1
]
+
past_length
,
dtype
=
tf
.
int32
)[
tf
.
newaxis
,
:]
position_ids
=
tf
.
tile
(
position_ids
,
[
shape_list
(
input_ids
)[
0
],
1
])
# Attention mask.
if
attention_mask
is
not
None
:
...
...
@@ -276,7 +277,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
token_type_embeds
=
0
position_ids
=
tf
.
reshape
(
position_ids
,
[
-
1
,
shape_list
(
position_ids
)[
-
1
]])
inputs_embeds
=
self
.
w
(
input_ids
)
inputs_embeds
=
self
.
w
(
input_ids
,
mode
=
'embedding'
)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_shape
[
-
1
]
mask
=
1
-
tf
.
linalg
.
band_part
(
tf
.
ones
((
seq_len
,
seq_len
)),
-
1
,
0
)
...
...
transformers/tests/modeling_tf_common_test.py
View file @
c19b8e4a
...
...
@@ -81,8 +81,9 @@ class TFCommonTestCases:
pt_model_class_name
=
model_class
.
__name__
[
2
:]
# Skip the "TF" at the beginning
pt_model_class
=
getattr
(
transformers
,
pt_model_class_name
)
tf_model
=
model_class
(
config
,
output_hidden_states
=
True
)
pt_model
=
pt_model_class
(
config
,
output_hidden_states
=
True
)
config
.
output_hidden_states
=
True
tf_model
=
model_class
(
config
)
pt_model
=
pt_model_class
(
config
)
# Check we can load pt model in tf and vice-versa (architecture similar)
tf_model
=
transformers
.
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
inputs_dict
)
...
...
@@ -96,7 +97,7 @@ class TFCommonTestCases:
pto
=
pt_model
(
**
pt_inputs_dict
)
tfo
=
tf_model
(
inputs_dict
)
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
[
0
].
numpy
()
-
pto
[
0
].
numpy
()))
self
.
assertLessEqual
(
max_diff
,
2e-2
)
self
.
assertLessEqual
(
max_diff
,
2e-5
)
def
test_keyword_and_dict_args
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment