chenpangpang / transformers · Commits

Commit dc894411
authored Oct 07, 2019 by thomwolf
update CTRL pytorch model
parent 320b7a7e

Showing 1 changed file with 246 additions and 161 deletions

transformers/modeling_ctrl.py  (+246, -161)
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch CTRL model."""
+""" PyTorch CTRL model."""
 from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,7 +27,6 @@ from io import open
 import numpy as np
 import torch
 import torch.nn as nn
-import pdb
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
@@ -46,7 +45,9 @@ def angle_defn(pos, i, d_model_size):
 def positional_encoding(position, d_model_size, dtype):
     # create the sinusoidal pattern for the positional encoding
-    angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1), torch.arange(d_model_size, dtype=dtype).unsqueeze(0), d_model_size))
+    angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
+                             torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
+                             d_model_size))

     sines = torch.sin(angle_rads[:, 0::2])
     cosines = torch.cos(angle_rads[:, 1::2])
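For orientation (not part of the diff): the table built here is the standard sinusoidal positional encoding; the sketch below re-implements the angle computation inline (angle_defn itself sits outside this hunk) purely to show the resulting shape.

# Illustrative, self-contained sketch of the table positional_encoding() builds.
import torch

def sinusoid_table(position, d_model_size, dtype=torch.float):
    pos = torch.arange(position, dtype=dtype).unsqueeze(1)        # (position, 1)
    i = torch.arange(d_model_size, dtype=dtype).unsqueeze(0)      # (1, d_model_size)
    angle_rads = pos / torch.pow(10000, (2 * (i // 2)) / d_model_size)
    sines = torch.sin(angle_rads[:, 0::2])
    cosines = torch.cos(angle_rads[:, 1::2])
    return torch.cat([sines, cosines], dim=-1).unsqueeze(0)       # (1, position, d_model_size)

print(sinusoid_table(256, 1280).shape)  # torch.Size([1, 256, 1280]), a CTRL-sized table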
@@ -54,7 +55,7 @@ def positional_encoding(position, d_model_size, dtype):
     pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
     return pos_encoding

-def scaled_dot_product_attention(q, k, v, mask):
+def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     # calculate attention
     matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
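To make the new signature concrete, here is a hedged, self-contained toy (not the module's code) that mirrors how the three masks combine in this hunk and the next: the causal mask is added as a large negative bias, the padding mask is added directly to the logits, and the head mask multiplies the softmax weights.

# Toy re-implementation for illustration only.
import torch

def toy_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    d_k = q.size(-1)
    logits = torch.matmul(q, k.permute(0, 1, 3, 2)) / d_k ** 0.5
    logits += mask * -1e4                    # causal mask: 1.0 where future positions are blocked
    if attention_mask is not None:
        logits = logits + attention_mask     # additive padding mask (0 / -10000)
    weights = torch.softmax(logits, dim=-1)
    if head_mask is not None:
        weights = weights * head_mask        # zero out pruned heads
    return torch.matmul(weights, v), weights

bsz, heads, seq, depth = 2, 4, 5, 8
q = k = v = torch.randn(bsz, heads, seq, depth)
causal = torch.triu(torch.ones(seq, seq), diagonal=1)   # upper triangle blocks the future
out, w = toy_attention(q, k, v, causal)
print(out.shape, w.shape)   # torch.Size([2, 4, 5, 8]) torch.Size([2, 4, 5, 5])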
@@ -64,15 +65,25 @@ def scaled_dot_product_attention(q, k, v, mask):
     if mask is not None:
         scaled_attention_logits += (mask * -1e4)

+    if attention_mask is not None:
+        # Apply the attention mask
+        scaled_attention_logits = scaled_attention_logits + attention_mask
+
     attention_weights = torch.softmax(scaled_attention_logits, dim=-1)

+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_weights = attention_weights * head_mask
+
     output = torch.matmul(attention_weights, v)

     return output, attention_weights


 class MultiHeadAttention(torch.nn.Module):
-    def __init__(self, d_model_size, num_heads):
+    def __init__(self, d_model_size, num_heads, output_attentions=False):
         super(MultiHeadAttention, self).__init__()
+        self.output_attentions = output_attentions
         self.num_heads = num_heads
         self.d_model_size = d_model_size
@@ -88,7 +99,7 @@ class MultiHeadAttention(torch.nn.Module):
         x = x.reshape(batch_size, -1, self.num_heads, self.depth)
         return x.permute([0, 2, 1, 3])

-    def forward(self, v, k, q, mask):
+    def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
         batch_size = q.shape[0]

         q = self.Wq(q)
@@ -98,7 +109,13 @@ class MultiHeadAttention(torch.nn.Module):
         q = self.split_into_heads(q, batch_size)
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
-        output = scaled_dot_product_attention(q, k, v, mask)
+        if layer_past is not None:
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
+            k = torch.cat((past_key, k), dim=-1)
+            v = torch.cat((past_value, v), dim=-2)
+        present = torch.stack((k.transpose(-2, -1), v))  # transpose to have same shapes for stacking
+
+        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions)
         scaled_attention = output[0].permute([0, 2, 1, 3])
+        attn = output[1]
         original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
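For readers new to the layer_past / present pair: it is an incremental-decoding key/value cache. Below is a simplified, self-contained sketch of the idea only, with keys and values both kept as (batch, heads, seq, depth) and appended along the sequence axis; the hunk above instead stores keys in a transposed, GPT-2-style layout, so this is not its exact tensor arithmetic.

# Minimal sketch of a past/present key-value cache for incremental decoding.
import torch

bsz, heads, depth = 1, 16, 80
cache = None  # plays the role of `layer_past`

for step in range(3):
    k_new = torch.randn(bsz, heads, 1, depth)   # projections for the newest token
    v_new = torch.randn(bsz, heads, 1, depth)
    if cache is not None:
        k = torch.cat((cache[0], k_new), dim=-2)   # append along the sequence axis
        v = torch.cat((cache[1], v_new), dim=-2)
    else:
        k, v = k_new, v_new
    cache = torch.stack((k, v))                  # plays the role of `present`
    print(step, cache.shape)                     # the sequence dimension grows: 1, 2, 3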
@@ -109,14 +126,16 @@ class MultiHeadAttention(torch.nn.Module):
 def point_wise_feed_forward_network(d_model_size, dff):
-    return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size))
+    return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff),
+                               torch.nn.ReLU(),
+                               torch.nn.Linear(dff, d_model_size))


 class EncoderLayer(torch.nn.Module):
-    def __init__(self, d_model_size, num_heads, dff, rate=0.1):
+    def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
         super(EncoderLayer, self).__init__()

-        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
+        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
         self.ffn = point_wise_feed_forward_network(d_model_size, dff)

         self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
@@ -125,9 +144,12 @@ class EncoderLayer(torch.nn.Module):
         self.dropout1 = torch.nn.Dropout(rate)
         self.dropout2 = torch.nn.Dropout(rate)

-    def forward(self, x, mask):
+    def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
         normed = self.layernorm1(x)
-        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask)
+        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
+                                                      layer_past=layer_past,
+                                                      attention_mask=attention_mask,
+                                                      head_mask=head_mask)
         attn_output = self.dropout1(attn_output)
         out1 = x + attn_output
@@ -147,9 +169,6 @@ class CTRLPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(CTRLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weights(self, module):
         """ Initialize the weights.
         """
@@ -256,7 +275,11 @@ class CTRLModel(CTRLPreTrainedModel):
         self.dropout = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([EncoderLayer(config.n_embd,
+                                             config.n_head,
+                                             config.dff,
+                                             config.resid_pdrop,
+                                             config.output_attentions) for _ in range(config.n_layer)])
         self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         self.init_weights()
@@ -272,8 +295,54 @@ class CTRLModel(CTRLPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None):
+    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past is None:
+            past_length = 0
+            past = [None] * len(self.h)
+        else:
+            past_length = past[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+        # Attention mask.
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(-1, input_shape[-1])
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and -10000.0 for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+        else:
+            head_mask = [None] * self.config.n_layer

         embedded = self.w(input_ids)
         x = embedded.unsqueeze(0) if len(input_ids.shape) < 2 else embedded
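A small standalone illustration of the attention-mask preparation above: a 2D padding mask (1 = attend, 0 = padded) becomes a broadcastable additive mask.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.float)   # batch of 1, seq of 5
mask = attention_mask.unsqueeze(1).unsqueeze(2)     # (batch, 1, 1, to_seq_length)
mask = (1.0 - mask) * -10000.0                      # 0.0 where we attend, -10000.0 at padded positions
print(mask.shape)                                   # torch.Size([1, 1, 1, 5])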
@@ -282,26 +351,40 @@ class CTRLModel(CTRLPreTrainedModel):
         x *= np.sqrt(self.d_model_size)

-        x += self.pos_encoding[:, :seq_len, :].to(x.device)
+        x += self.pos_encoding[:, position_ids, :].to(x.device)

         x = self.dropout(x)

+        output_shape = input_shape + (x.size(-1),)
+        presents = ()
         all_hidden_states = ()
         all_attentions = []
-        for i in range(self.num_layers):
+        for i, (h, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states = all_hidden_states + (x,)
-            x, attn = self.h[i](x, mask)
+                all_hidden_states = all_hidden_states + (x.view(*output_shape),)
+            outputs = h(x,
+                        mask,
+                        layer_past=layer_past,
+                        attention_mask=attention_mask,
+                        head_mask=head_mask[i])
+            x, present = outputs[:2]
+            presents = presents + (present,)
+
             if self.output_attentions:
-                all_attentions.append(attn)
+                all_attentions.append(outputs[2])

         x = self.layernorm(x)
+        x = x.view(*output_shape)
         if self.output_hidden_states:
             all_hidden_states = all_hidden_states + (x,)

-        outputs = (x, None)
+        outputs = (x, presents)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
+            # let the number of heads free (-1) so we can extract attention even after head pruning
+            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
             outputs = outputs + (all_attentions,)
         return outputs
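The final attention reshape above keeps the head axis free (-1) so attentions can still be recovered after head pruning, while restoring any leading batch dimensions that were flattened at the top of forward. A standalone illustration with toy shapes:

import torch

input_shape = torch.Size([2, 3, 8])      # e.g. (batch, num_choices, seq) before flattening
attn = torch.randn(2 * 3, 4, 8, 8)       # one layer's attention: (batch*choices, heads, seq, seq)
attention_output_shape = input_shape[:-1] + (-1,) + attn.shape[-2:]
print(attn.view(*attention_output_shape).shape)   # torch.Size([2, 3, 4, 8, 8])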
@@ -359,13 +442,17 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
-        #self._tie_or_clone_weights(self.lm_head.bias,
-        #                           self.transformer.w.bias)
+        self._tie_or_clone_weights(self.lm_head, self.transformer.w)

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None):
-        transformer_outputs = self.transformer(input_ids)
+        transformer_outputs = self.transformer(input_ids,
+                                               past=past,
+                                               attention_mask=attention_mask,
+                                               token_type_ids=token_type_ids,
+                                               position_ids=position_ids,
+                                               head_mask=head_mask)

         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
@@ -383,5 +470,3 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             outputs = (loss,) + outputs

         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
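The (loss,) element prepended above is the usual shifted language-modeling loss (tokens before position n predict token n). A hedged, self-contained illustration with toy sizes, not the file's exact code:

import torch
from torch.nn import CrossEntropyLoss

vocab, batch, seq = 11, 2, 6
lm_logits = torch.randn(batch, seq, vocab)          # stand-in for self.lm_head(hidden_states)
labels = torch.randint(0, vocab, (batch, seq))

shift_logits = lm_logits[..., :-1, :].contiguous()  # drop the last position's prediction
shift_labels = labels[..., 1:].contiguous()         # drop the first label
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())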