chenpangpang/transformers, commit 3edfa1d6
Authored Oct 08, 2019 by thomwolf
Parent: bd5363cc

Commit message: update model to use past
Showing 2 changed files, with 19 additions and 11 deletions:

    transformers/modeling_ctrl.py              +16  -10
    transformers/tests/modeling_ctrl_test.py    +3   -1

transformers/modeling_ctrl.py
@@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype):
     sines = torch.sin(angle_rads[:, 0::2])
     cosines = torch.cos(angle_rads[:, 1::2])
-    pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
+    pos_encoding = torch.cat([sines, cosines], dim=-1)
     return pos_encoding
 
 
 def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):

@@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
+            past_key, past_value = layer_past[0], layer_past[1]
             k = torch.cat((past_key, k), dim=-1)
             v = torch.cat((past_value, v), dim=-2)
-        present = torch.stack((k.transpose(-2, -1), v))  # transpose to have same shapes for stacking
+        present = torch.stack((k, v))
 
-        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions)
+        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = output[0].permute([0, 2, 1, 3])
         attn = output[1]
         original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
         output = self.dense(original_size_attention)
 
-        return output, attn
+        outputs = (output, present)
+        if self.output_attentions:
+            outputs = outputs + (attn,)
+        return outputs

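For context on the MultiHeadAttention hunk above: the change moves CTRL toward the usual key/value cache for incremental decoding. Each attention call stacks its keys and values into a present tensor, and on the next step that tensor comes back in as layer_past and is concatenated onto the new keys/values along the sequence axis. Below is a minimal, self-contained sketch of that pattern; the function name, shapes, and the plain (unmasked) softmax attention are illustrative assumptions, not the exact code in modeling_ctrl.py.

import torch

# Illustrative shapes: (batch, num_heads, seq_len, head_dim).
batch, num_heads, head_dim = 2, 4, 8

def attend_with_cache(q, k, v, layer_past=None):
    # Prepend cached keys/values along the sequence dimension, if a cache exists.
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        k = torch.cat((past_key, k), dim=-2)
        v = torch.cat((past_value, v), dim=-2)
    # Stack keys and values so the caller can carry them to the next step.
    present = torch.stack((k, v))
    weights = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
    return weights @ v, present

# Step 1: process a prompt of length 5 with no cache.
q = k = v = torch.randn(batch, num_heads, 5, head_dim)
out, present = attend_with_cache(q, k, v)

# Step 2: feed a single new token and reuse the cache from step 1.
q2 = k2 = v2 = torch.randn(batch, num_heads, 1, head_dim)
out2, present2 = attend_with_cache(q2, k2, v2, layer_past=present)
print(present2.shape)  # torch.Size([2, 2, 4, 6, 8]); the cached length grew from 5 to 6
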
@@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module):
     def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
         normed = self.layernorm1(x)
-        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
+        attn_outputs = self.multi_head_attention(normed, normed, normed, mask, layer_past=layer_past,
                                                       attention_mask=attention_mask,
                                                       head_mask=head_mask)
+        attn_output = attn_outputs[0]
         attn_output = self.dropout1(attn_output)
         out1 = x + attn_output

@@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module):
         ffn_output = self.dropout2(ffn_output)
         out2 = out1 + ffn_output
 
-        return out2, attn
+        outputs = (out2,) + attn_outputs[1:]
+        return outputs
 
 
 class CTRLPreTrainedModel(PreTrainedModel):

@@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
 
-        embedded = self.w(input_ids)
-        x = embedded.unsqueeze(0) if len(input_ids.shape) < 2 else embedded
+        x = self.w(input_ids)
+        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_ids.shape[1]
         mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
 
         x *= np.sqrt(self.d_model_size)
-        x += self.pos_encoding[:, position_ids, :].to(x.device)
+        pos_x = self.pos_encoding[position_ids, :].to(x.device)
+        x += pos_x
         x = self.dropout(x)

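One way to read the CTRLModel hunk above: with the unsqueeze(0) gone from positional_encoding, pos_encoding is a plain 2-D table of shape (n_positions, d_model_size), and indexing it with a (batch, seq_len) position_ids tensor broadcasts into a (batch, seq_len, d_model_size) block that adds directly onto x. A toy check of that indexing behaviour follows; the table contents are dummy values, not the real sinusoidal encoding.

import torch

n_positions, d_model_size = 10, 4
# Dummy 2-D positional table; no leading batch dimension any more.
pos_encoding = torch.arange(n_positions * d_model_size, dtype=torch.float).reshape(
    n_positions, d_model_size)

# A batch of two sequences of length 3.
position_ids = torch.tensor([[0, 1, 2],
                             [3, 4, 5]])

pos_x = pos_encoding[position_ids, :]  # advanced indexing broadcasts over the batch
print(pos_x.shape)                     # torch.Size([2, 3, 4])
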
transformers/tests/modeling_ctrl_test.py
@@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
             model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
             model(input_ids, token_type_ids=token_type_ids)
 
-            sequence_output, _ = model(input_ids)
+            sequence_output, presents = model(input_ids)
 
             result = {
                 "sequence_output": sequence_output,
+                "presents": presents,
             }
             self.parent.assertListEqual(
                 list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
+            self.parent.assertEqual(len(result["presents"]), config.n_layer)
 
         def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
             model = CTRLLMHeadModel(config)

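The new assertion in the test pins down the structure of what the model now returns alongside the hidden states: one cached (key, value) stack per transformer layer, so len(presents) equals config.n_layer. A toy sketch of the collection pattern being tested, with plain random tensors standing in for the real per-layer outputs:

import torch

n_layer, batch, num_heads, seq_len, head_dim = 3, 2, 4, 5, 8

presents = ()
for _ in range(n_layer):
    # Each layer contributes its stacked (key, value) tensors.
    k = torch.randn(batch, num_heads, seq_len, head_dim)
    v = torch.randn(batch, num_heads, seq_len, head_dim)
    presents = presents + (torch.stack((k, v)),)

assert len(presents) == n_layer  # mirrors assertEqual(len(result["presents"]), config.n_layer)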