chenpangpang/transformers, commit dc894411 (parent 320b7a7e)
authored Oct 07, 2019 by thomwolf

    update CTRL pytorch model

1 changed file: transformers/modeling_ctrl.py (+246 additions, -161 deletions)

transformers/modeling_ctrl.py:

@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch CTRL model."""
+"""
+PyTorch CTRL model."""
 from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,7 +27,6 @@ from io import open
 import numpy as np
 import torch
 import torch.nn as nn
-import pdb
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
@@ -46,7 +45,9 @@ def angle_defn(pos, i, d_model_size):
 def positional_encoding(position, d_model_size, dtype):
     # create the sinusoidal pattern for the positional encoding
     angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
                              torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
                              d_model_size))
     sines = torch.sin(angle_rads[:, 0::2])
     cosines = torch.cos(angle_rads[:, 1::2])
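
The positional-encoding hunk only shows context, so for orientation here is a standalone sketch of the sinusoidal scheme it implements. The body of angle_defn is not visible in this diff, so the rate formula below is the standard Transformer definition and should be treated as an assumption; the final two lines are just a shape check.

import torch

def angle_defn(pos, i, d_model_size):
    # standard Transformer rate: pos / 10000^(2*(i//2)/d_model)  (assumed, not shown in the diff)
    angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
    return pos * angle_rates

def positional_encoding(position, d_model_size, dtype):
    # sin on even indices, cos on odd indices, concatenated and batched
    angle_rads = angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
                            torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
                            d_model_size)
    sines = torch.sin(angle_rads[:, 0::2])
    cosines = torch.cos(angle_rads[:, 1::2])
    return torch.cat([sines, cosines], dim=-1).unsqueeze(0)

pe = positional_encoding(10, 16, torch.float)
print(pe.shape)  # torch.Size([1, 10, 16])
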
@@ -54,7 +55,7 @@ def positional_encoding(position, d_model_size, dtype):
     pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
     return pos_encoding

-def scaled_dot_product_attention(q, k, v, mask):
+def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     # calculate attention
     matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
@@ -64,15 +65,25 @@ def scaled_dot_product_attention(q, k, v, mask):
     if mask is not None:
         scaled_attention_logits += (mask * -1e4)

+    if attention_mask is not None:
+        # Apply the attention mask
+        scaled_attention_logits = scaled_attention_logits + attention_mask
+
     attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
+
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_weights = attention_weights * head_mask
+
     output = torch.matmul(attention_weights, v)

     return output, attention_weights


 class MultiHeadAttention(torch.nn.Module):
-    def __init__(self, d_model_size, num_heads):
+    def __init__(self, d_model_size, num_heads, output_attentions=False):
         super(MultiHeadAttention, self).__init__()
+        self.output_attentions = output_attentions
         self.num_heads = num_heads
         self.d_model_size = d_model_size
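
What the two new arguments do, in isolation: attention_mask is added to the logits so that padded key positions get a large negative score and vanish after the softmax, while head_mask multiplies the attention probabilities to silence pruned heads. Below is a minimal self-contained sketch of that mechanism; the -1e4 and -10000.0 constants follow the diff, while the scaling line and the toy shapes are illustrative.

import torch

def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    # q, k, v: (batch, heads, seq, depth); mask: 1s above the diagonal for causal masking
    matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
    scaled_attention_logits = matmul_qk / (k.shape[-1] ** 0.5)
    if mask is not None:
        scaled_attention_logits += mask * -1e4                               # causal mask
    if attention_mask is not None:
        scaled_attention_logits = scaled_attention_logits + attention_mask   # additive padding mask
    attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
    if head_mask is not None:
        attention_weights = attention_weights * head_mask                    # zero out masked heads
    return torch.matmul(attention_weights, v), attention_weights

batch, heads, seq, depth = 2, 4, 5, 8
q = k = v = torch.randn(batch, heads, seq, depth)
causal = torch.triu(torch.ones(seq, seq), diagonal=1)            # 1 = future position, masked
pad = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]).float()   # 1 = attend, 0 = padding
additive = (1.0 - pad)[:, None, None, :] * -10000.0
out, attn = scaled_dot_product_attention(q, k, v, causal, attention_mask=additive)
print(out.shape, attn.shape)  # torch.Size([2, 4, 5, 8]) torch.Size([2, 4, 5, 5])
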
@@ -88,7 +99,7 @@ class MultiHeadAttention(torch.nn.Module):
         x = x.reshape(batch_size, -1, self.num_heads, self.depth)
         return x.permute([0, 2, 1, 3])

-    def forward(self, v, k, q, mask):
+    def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
         batch_size = q.shape[0]

         q = self.Wq(q)
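
For reference, split_into_heads (shown as context above) turns (batch, seq, d_model) into (batch, heads, seq, depth) with a reshape followed by a permute. A quick shape check with toy sizes:

import torch

batch, seq, num_heads, depth = 2, 5, 4, 8
x = torch.randn(batch, seq, num_heads * depth)   # d_model_size = 32

x = x.reshape(batch, -1, num_heads, depth)       # (2, 5, 4, 8)
x = x.permute([0, 2, 1, 3])                      # heads become a batch-like dim
print(x.shape)  # torch.Size([2, 4, 5, 8])
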
@@ -98,7 +109,13 @@ class MultiHeadAttention(torch.nn.Module):
         q = self.split_into_heads(q, batch_size)
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
+        if layer_past is not None:
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
+            k = torch.cat((past_key, k), dim=-1)
+            v = torch.cat((past_value, v), dim=-2)
+        present = torch.stack((k.transpose(-2, -1), v))  # transpose to have same shapes for stacking

-        output = scaled_dot_product_attention(q, k, v, mask)
+        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions)
         scaled_attention = output[0].permute([0, 2, 1, 3])
         attn = output[1]
         original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
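
The layer_past branch is what enables incremental decoding: keys and values from earlier steps are cached, so only the new token has to be projected at each step. The diff stores the key transposed (mirroring the GPT-2 attention in the same repository); the sketch below ignores that layout detail and keeps both tensors as (batch, heads, seq, depth) to show just the caching idea, with random tensors standing in for real projections.

import torch

batch, heads, depth = 1, 2, 4

past_k = torch.randn(batch, heads, 3, depth)   # 3 tokens already in the cache
past_v = torch.randn(batch, heads, 3, depth)

new_k = torch.randn(batch, heads, 1, depth)    # projection of the single new token
new_v = torch.randn(batch, heads, 1, depth)

k = torch.cat((past_k, new_k), dim=-2)         # sequence length grows from 3 to 4
v = torch.cat((past_v, new_v), dim=-2)
present = torch.stack((k, v))                  # what the layer hands back as its cache

print(k.shape, present.shape)  # torch.Size([1, 2, 4, 4]) torch.Size([2, 1, 2, 4, 4])
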
@@ -109,14 +126,16 @@ class MultiHeadAttention(torch.nn.Module):
 def point_wise_feed_forward_network(d_model_size, dff):
     return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size))


 class EncoderLayer(torch.nn.Module):
-    def __init__(self, d_model_size, num_heads, dff, rate=0.1):
+    def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
         super(EncoderLayer, self).__init__()

-        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
+        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
         self.ffn = point_wise_feed_forward_network(d_model_size, dff)

         self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
@@ -125,9 +144,12 @@ class EncoderLayer(torch.nn.Module):
         self.dropout1 = torch.nn.Dropout(rate)
         self.dropout2 = torch.nn.Dropout(rate)

-    def forward(self, x, mask):
+    def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
         normed = self.layernorm1(x)
-        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask)
+        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask)
         attn_output = self.dropout1(attn_output)
         out1 = x + attn_output
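
The layer keeps a pre-LayerNorm ordering: normalize, run the sub-layer, apply dropout, then add the residual to the un-normalized input. A tiny sketch of that data flow, with the attention call replaced by an identity stand-in so only the wiring is shown:

import torch

d_model = 16
layernorm1 = torch.nn.LayerNorm(d_model, eps=1e-6)
dropout1 = torch.nn.Dropout(0.1)
attention = torch.nn.Identity()      # stand-in for multi_head_attention

x = torch.randn(2, 5, d_model)
normed = layernorm1(x)               # normalize first (pre-LayerNorm)
attn_output = dropout1(attention(normed))
out1 = x + attn_output               # residual around the sub-layer
print(out1.shape)                    # torch.Size([2, 5, 16])
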
@@ -147,9 +169,6 @@ class CTRLPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(CTRLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weights(self, module):
         """ Initialize the weights.
         """
@@ -256,7 +275,11 @@ class CTRLModel(CTRLPreTrainedModel):
         self.dropout = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, config.output_attentions) for _ in range(config.n_layer)])
         self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         self.init_weights()
@@ -272,8 +295,54 @@ class CTRLModel(CTRLPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)

-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                labels=None):
+    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past is None:
+            past_length = 0
+            past = [None] * len(self.h)
+        else:
+            past_length = past[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+        # Attention mask.
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(-1, input_shape[-1])
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+            # this attention mask is more simple than the triangular masking of causal attention
+            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and -10000.0 for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+        else:
+            head_mask = [None] * self.config.n_layer
+
         embedded = self.w(input_ids)
         x = embedded.unsqueeze(0) if len(input_ids.shape) < 2 else embedded
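
The head-mask preparation accepts either one mask shared by all layers (1D) or one mask per layer (2D) and broadcasts it to something compatible with n_layer x batch x n_heads x seq x seq. The same reshaping, pulled out into a small standalone sketch (prepare_head_mask is a hypothetical helper name, not part of the file):

import torch

n_layer, n_heads = 3, 4

def prepare_head_mask(head_mask, n_layer):
    # 1D: one mask shared by every layer; 2D: one mask per layer.
    if head_mask.dim() == 1:
        head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)
    elif head_mask.dim() == 2:
        head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
    return head_mask

shared = torch.tensor([1., 1., 0., 1.])              # silence head 2 in every layer
per_layer = torch.ones(n_layer, n_heads)
per_layer[0, 0] = 0.                                  # silence head 0 of layer 0 only

print(prepare_head_mask(shared, n_layer).shape)      # torch.Size([3, 1, 4, 1, 1])
print(prepare_head_mask(per_layer, n_layer).shape)   # torch.Size([3, 1, 4, 1, 1])
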
@@ -282,26 +351,40 @@ class CTRLModel(CTRLPreTrainedModel):
         x *= np.sqrt(self.d_model_size)
-        x += self.pos_encoding[:, :seq_len, :].to(x.device)
+        x += self.pos_encoding[:, position_ids, :].to(x.device)

         x = self.dropout(x)

+        output_shape = input_shape + (x.size(-1),)
+        presents = ()
         all_hidden_states = ()
         all_attentions = []
-        for i in range(self.num_layers):
+        for i, (h, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states = all_hidden_states + (x,)
-            x, attn = self.h[i](x, mask)
+                all_hidden_states = all_hidden_states + (x.view(*output_shape),)
+            outputs = h(x, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i])
+            x, present = outputs[:2]
+            presents = presents + (present,)
             if self.output_attentions:
-                all_attentions.append(attn)
+                all_attentions.append(outputs[2])

         x = self.layernorm(x)
+        x = x.view(*output_shape)
         if self.output_hidden_states:
             all_hidden_states = all_hidden_states + (x,)

-        outputs = (x, None)
+        outputs = (x, presents)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
+            # let the number of heads free (-1) so we can extract attention even after head pruning
+            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
             outputs = outputs + (all_attentions,)

         return outputs
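
This is also why the positional encoding is now indexed with position_ids instead of a flat :seq_len slice: with a cache, the incoming input_ids hold only the new tokens, so their positions must continue from past_length rather than restart at zero. A minimal sketch of that offset, with past_length as a plain number:

import torch

past_length = 7                                   # tokens already covered by the cache
input_ids = torch.tensor([[42]])                  # only the newly generated token
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                            dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
print(position_ids)                               # tensor([[7]])
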
@@ -359,13 +442,17 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
         self._tie_or_clone_weights(self.lm_head,
                                    self.transformer.w)
+        #self._tie_or_clone_weights(self.lm_head.bias,
+        #                           self.transformer.w.bias)

     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
-        transformer_outputs = self.transformer(input_ids)
+        transformer_outputs = self.transformer(input_ids, past=past, attention_mask=attention_mask,
+                                               token_type_ids=token_type_ids, position_ids=position_ids,
+                                               head_mask=head_mask)
         hidden_states = transformer_outputs[0]

         lm_logits = self.lm_head(hidden_states)
@@ -383,5 +470,3 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             outputs = (loss,) + outputs

         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
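
The tail comment documents the output tuple: when labels are given, the loss is prepended to (lm_logits, presents, ...). The loss itself is the usual shift-by-one cross-entropy over the LM logits; the sketch below shows that conventional recipe on random tensors and is not copied from this diff.

import torch
from torch.nn import CrossEntropyLoss

batch, seq, vocab = 2, 6, 50
lm_logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())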