Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3edfa1d6
Commit
3edfa1d6
authored
Oct 08, 2019
by
thomwolf
Browse files
update model to use past
parent
bd5363cc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
11 deletions
+19
-11
transformers/modeling_ctrl.py
transformers/modeling_ctrl.py
+16
-10
transformers/tests/modeling_ctrl_test.py
transformers/tests/modeling_ctrl_test.py
+3
-1
No files found.
transformers/modeling_ctrl.py
View file @
3edfa1d6
...
@@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype):
...
@@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype):
sines
=
torch
.
sin
(
angle_rads
[:,
0
::
2
])
sines
=
torch
.
sin
(
angle_rads
[:,
0
::
2
])
cosines
=
torch
.
cos
(
angle_rads
[:,
1
::
2
])
cosines
=
torch
.
cos
(
angle_rads
[:,
1
::
2
])
pos_encoding
=
torch
.
cat
([
sines
,
cosines
],
dim
=-
1
)
.
unsqueeze
(
0
)
pos_encoding
=
torch
.
cat
([
sines
,
cosines
],
dim
=-
1
)
return
pos_encoding
return
pos_encoding
def
scaled_dot_product_attention
(
q
,
k
,
v
,
mask
,
attention_mask
=
None
,
head_mask
=
None
):
def
scaled_dot_product_attention
(
q
,
k
,
v
,
mask
,
attention_mask
=
None
,
head_mask
=
None
):
...
@@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module):
...
@@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module):
k
=
self
.
split_into_heads
(
k
,
batch_size
)
k
=
self
.
split_into_heads
(
k
,
batch_size
)
v
=
self
.
split_into_heads
(
v
,
batch_size
)
v
=
self
.
split_into_heads
(
v
,
batch_size
)
if
layer_past
is
not
None
:
if
layer_past
is
not
None
:
past_key
,
past_value
=
layer_past
[
0
]
.
transpose
(
-
2
,
-
1
),
layer_past
[
1
]
# transpose back cf below
past_key
,
past_value
=
layer_past
[
0
]
,
layer_past
[
1
]
k
=
torch
.
cat
((
past_key
,
k
),
dim
=-
1
)
k
=
torch
.
cat
((
past_key
,
k
),
dim
=-
1
)
v
=
torch
.
cat
((
past_value
,
v
),
dim
=-
2
)
v
=
torch
.
cat
((
past_value
,
v
),
dim
=-
2
)
present
=
torch
.
stack
((
k
.
transpose
(
-
2
,
-
1
),
v
))
# transpose to have same shapes for stacking
present
=
torch
.
stack
((
k
,
v
))
output
=
scaled_dot_product_attention
(
q
,
k
,
v
,
mask
,
attention_mask
,
head_mask
,
output_attentions
)
output
=
scaled_dot_product_attention
(
q
,
k
,
v
,
mask
,
attention_mask
,
head_mask
)
scaled_attention
=
output
[
0
].
permute
([
0
,
2
,
1
,
3
])
scaled_attention
=
output
[
0
].
permute
([
0
,
2
,
1
,
3
])
attn
=
output
[
1
]
attn
=
output
[
1
]
original_size_attention
=
scaled_attention
.
reshape
(
batch_size
,
-
1
,
self
.
d_model_size
)
original_size_attention
=
scaled_attention
.
reshape
(
batch_size
,
-
1
,
self
.
d_model_size
)
output
=
self
.
dense
(
original_size_attention
)
output
=
self
.
dense
(
original_size_attention
)
return
output
,
attn
outputs
=
(
output
,
present
)
if
self
.
output_attentions
:
outputs
=
outputs
+
(
attn
,)
return
outputs
...
@@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module):
...
@@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module):
def
forward
(
self
,
x
,
mask
,
layer_past
=
None
,
attention_mask
=
None
,
head_mask
=
None
):
def
forward
(
self
,
x
,
mask
,
layer_past
=
None
,
attention_mask
=
None
,
head_mask
=
None
):
normed
=
self
.
layernorm1
(
x
)
normed
=
self
.
layernorm1
(
x
)
attn_output
,
attn
=
self
.
multi_head_attention
(
normed
,
normed
,
normed
,
mask
,
attn_output
s
=
self
.
multi_head_attention
(
normed
,
normed
,
normed
,
mask
,
layer_past
=
layer_past
,
layer_past
=
layer_past
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
head_mask
=
head_mask
)
attn_output
=
attn_outputs
[
0
]
attn_output
=
self
.
dropout1
(
attn_output
)
attn_output
=
self
.
dropout1
(
attn_output
)
out1
=
x
+
attn_output
out1
=
x
+
attn_output
...
@@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module):
...
@@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module):
ffn_output
=
self
.
dropout2
(
ffn_output
)
ffn_output
=
self
.
dropout2
(
ffn_output
)
out2
=
out1
+
ffn_output
out2
=
out1
+
ffn_output
return
out2
,
attn
outputs
=
(
out2
,)
+
attn_outputs
[
1
:]
return
outputs
class
CTRLPreTrainedModel
(
PreTrainedModel
):
class
CTRLPreTrainedModel
(
PreTrainedModel
):
...
@@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel):
...
@@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel):
else
:
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
head_mask
=
[
None
]
*
self
.
config
.
n_layer
embedded
=
self
.
w
(
input_ids
)
x
=
self
.
w
(
input_ids
)
x
=
embedded
.
unsqueeze
(
0
)
if
len
(
input_ids
.
shape
)
<
2
else
embedded
#
x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_ids
.
shape
[
1
]
seq_len
=
input_ids
.
shape
[
1
]
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
x
.
device
)
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
x
.
device
)
x
*=
np
.
sqrt
(
self
.
d_model_size
)
x
*=
np
.
sqrt
(
self
.
d_model_size
)
x
+=
self
.
pos_encoding
[:,
position_ids
,
:].
to
(
x
.
device
)
pos_x
=
self
.
pos_encoding
[
position_ids
,
:].
to
(
x
.
device
)
x
+=
pos_x
x
=
self
.
dropout
(
x
)
x
=
self
.
dropout
(
x
)
...
...
transformers/tests/modeling_ctrl_test.py
View file @
3edfa1d6
...
@@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
...
@@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
model
(
input_ids
,
token_type_ids
=
token_type_ids
)
sequence_output
,
_
=
model
(
input_ids
)
sequence_output
,
presents
=
model
(
input_ids
)
result
=
{
result
=
{
"sequence_output"
:
sequence_output
,
"sequence_output"
:
sequence_output
,
"presents"
:
presents
,
}
}
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
size
()),
list
(
result
[
"sequence_output"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertEqual
(
len
(
result
[
"presents"
]),
config
.
n_layer
)
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
CTRLLMHeadModel
(
config
)
model
=
CTRLLMHeadModel
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment