chenpangpang / transformers · Commits · b6992b7b

Commit b6992b7b, authored Aug 31, 2019 by LysandreJik
Parent: bdb4409e

Applied patch to OpenAI GPT, RoBERTa, TransfoXL, XLM and XLNet
Showing 5 changed files with 27 additions and 41 deletions (+27, -41):
pytorch_transformers/modeling_openai.py      +8  -17
pytorch_transformers/modeling_roberta.py     +2  -2
pytorch_transformers/modeling_transfo_xl.py  +3  -6
pytorch_transformers/modeling_xlm.py         +9  -8
pytorch_transformers/modeling_xlnet.py       +5  -8
pytorch_transformers/modeling_openai.py
@@ -249,14 +249,15 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        self.pruned_heads = []
+        self.pruned_heads = set()

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
-            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -267,7 +268,7 @@ class Attention(nn.Module):
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
-        self.pruned_heads.extend(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)

     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
@@ -366,10 +367,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
@@ -459,14 +457,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

-        if hasattr(config, "pruned_heads"):
-            pruned_heads = config.pruned_heads.copy().items()
-            config.pruned_heads = {}
-            for layer, heads in pruned_heads:
-                if self.h[int(layer)].attn.n_head == config.n_head:
-                    self.prune_heads({int(layer): list(map(int, heads))})
-
-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
@@ -579,7 +570,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
@@ -686,7 +677,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.multiple_choice_head = SequenceSummary(config)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
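Viewed together, the two Attention hunks above change the pruned-head bookkeeping from a list to a set: an already-pruned index is silently skipped, incoming indices (which refer to the original head layout) are shifted past heads pruned in earlier calls, and the running record is accumulated with union instead of extend. A minimal, self-contained sketch of that bookkeeping follows; ToyAttention and head_dim are illustrative names, not classes from the library.

import torch


class ToyAttention:
    """Toy stand-in for the pruning bookkeeping in the Attention hunks (illustrative only)."""

    def __init__(self, n_head, head_dim):
        self.n_head = n_head
        self.split_size = n_head * head_dim
        self.pruned_heads = set()  # the patch turns this from a list into a set

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        # Skip heads that were already pruned in an earlier call.
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Head indices refer to the original layout, so shift each index left
            # by the number of already-pruned heads that preceded it.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()  # surviving weight columns
        # Update hyper-parameters the same way the patched modules do.
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)
        return index


att = ToyAttention(n_head=4, head_dim=2)
att.prune_heads([1])         # prune original head 1
att.prune_heads([1, 3])      # head 1 is ignored (already pruned); head 3 remaps to row 2
assert att.pruned_heads == {1, 3} and att.n_head == 2 and att.split_size == 4

The second call shows why the set matters: without the set difference, head 1 would be pruned twice and the index arithmetic would drift.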
pytorch_transformers/modeling_roberta.py
@@ -168,7 +168,7 @@ class RobertaModel(BertModel):
         super(RobertaModel, self).__init__(config)
         self.embeddings = RobertaEmbeddings(config)
-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None,
                 head_mask=None):
         if input_ids[:, 0].sum().item() != 0:
@@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.lm_head = RobertaLMHead(config)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
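As in the other files of this commit, the constructors above now call self.init_weights() with no argument, and the per-module initializer moves behind an underscore as _init_weights. The base-class side of that contract is not shown in this diff; assuming PreTrainedModel.init_weights() simply applies the subclass's _init_weights hook to every submodule, the intended division of labour would look roughly like the sketch below (SketchPreTrainedModel, SketchModel and config.hidden_size are hypothetical stand-ins, not the library's classes).

from types import SimpleNamespace

import torch.nn as nn


class SketchPreTrainedModel(nn.Module):
    """Hypothetical stand-in for the shared base class (its side of the change is not in this diff)."""

    def __init__(self, config):
        super(SketchPreTrainedModel, self).__init__()
        self.config = config

    def init_weights(self):
        # Assumed behaviour: walk every submodule and hand it to the
        # model-specific _init_weights hook defined by the subclass.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        raise NotImplementedError


class SketchModel(SketchPreTrainedModel):
    def __init__(self, config):
        super(SketchModel, self).__init__(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.init_weights()  # replaces the old self.apply(self.init_weights)

    def _init_weights(self, module):
        # Per-module rule, analogous to the renamed hooks in this commit.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


model = SketchModel(SimpleNamespace(hidden_size=8))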
pytorch_transformers/modeling_transfo_xl.py
@@ -853,9 +853,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weight(self, weight):
         if self.config.init == 'uniform':
             nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
@@ -865,7 +862,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     def _init_bias(self, bias):
         nn.init.constant_(bias, 0.0)

-    def init_weights(self, m):
+    def _init_weights(self, m):
         """ Initialize the weights.
         """
         classname = m.__class__.__name__
@@ -1059,7 +1056,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             self.r_emb = nn.Parameter(torch.FloatTensor(
                     self.n_layer, self.max_klen, self.n_head, self.d_head))

-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
@@ -1306,7 +1303,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         else:
             self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                     config.cutoffs, div_val=config.div_val)
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
pytorch_transformers/modeling_xlm.py
@@ -271,15 +271,16 @@ class MultiHeadAttention(nn.Module):
         self.k_lin = nn.Linear(dim, dim)
         self.v_lin = nn.Linear(dim, dim)
         self.out_lin = nn.Linear(dim, dim)
-        self.pruned_heads = []
+        self.pruned_heads = set()

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
-            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -291,7 +292,7 @@ class MultiHeadAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
-        self.pruned_heads.extend(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)

     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
@@ -386,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights. """
         if isinstance(module, nn.Embedding):
             if self.config is not None and self.config.embed_init_std is not None:
@@ -569,7 +570,7 @@ class XLMModel(XLMPreTrainedModel):
                 if self.attentions[int(layer)].n_heads == config.n_heads:
                     self.prune_heads({int(layer): list(map(int, heads))})

-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
@@ -781,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
@@ -843,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.sequence_summary = SequenceSummary(config)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, labels=None, head_mask=None):
@@ -921,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = SQuADHead(config)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, start_positions=None, end_positions=None,
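A side effect of the MultiHeadAttention hunks in this file is that attention_head_size is computed from the pre-pruning shape and then reused to shrink dim, so the per-head width stays constant while the total width drops with the surviving head count. A quick arithmetic check with made-up numbers (illustrative values, not taken from any config in this diff):

# Illustrative values only.
dim, n_heads = 512, 8
attention_head_size = dim // n_heads      # 64, computed before pruning

heads_to_prune = {1, 5}                   # prune two heads
n_heads = n_heads - len(heads_to_prune)   # 6
dim = attention_head_size * n_heads       # 384: per-head width unchanged

assert (attention_head_size, n_heads, dim) == (64, 6, 384)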
pytorch_transformers/modeling_xlnet.py
@@ -586,10 +586,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -736,7 +733,7 @@ class XLNetModel(XLNetPreTrainedModel):
         self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)

-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
@@ -1037,7 +1034,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
@@ -1114,7 +1111,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
@@ -1216,7 +1213,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.end_logits = PoolerEndLogits(config)
         self.answer_class = PoolerAnswerClass(config)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,