chenpangpang / transformers · Commits

Commit 2b56e988, authored Jun 28, 2019 by thomwolf
Parent: 3a00674c

    standardizing API across models - XLNetForSeqClass working

Showing 4 changed files with 277 additions and 247 deletions:

    examples/run_xlnet_classifier.py           +32   -27
    pytorch_pretrained_bert/modeling.py        +63   -74
    pytorch_pretrained_bert/modeling_xlm.py    +74   -58
    pytorch_pretrained_bert/modeling_xlnet.py  +108  -88
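The thread running through all four files is a single return convention: every forward() hands back a plain Python list whose first entries are the main results, followed by optional hidden states and attentions controlled by output_hidden_states / output_attentions flags, replacing the old keep_multihead_output machinery and flag-dependent tuple shapes. A self-contained toy sketch of that convention (ToyEncoder is illustrative only, not one of the classes touched by this commit):

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        """Illustrative only: mimics the list-style return convention this commit standardizes."""
        def __init__(self, output_attentions=False, output_hidden_states=False):
            super(ToyEncoder, self).__init__()
            self.output_attentions = output_attentions
            self.output_hidden_states = output_hidden_states
            self.linear = nn.Linear(8, 8)

        def forward(self, hidden_states):
            all_hidden_states, all_attentions = [hidden_states], []
            hidden_states = self.linear(hidden_states)
            attentions = torch.softmax(hidden_states, dim=-1)   # stand-in for real attention probs
            all_hidden_states.append(hidden_states)
            all_attentions.append(attentions)

            outputs = [hidden_states]                 # main output always comes first
            if self.output_hidden_states:
                outputs.append(all_hidden_states)     # then optional hidden states
            if self.output_attentions:
                outputs.append(all_attentions)        # then optional attentions
            return outputs

    enc = ToyEncoder(output_attentions=True, output_hidden_states=True)
    outputs = enc(torch.randn(2, 4, 8))
    last_hidden = outputs[0]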
examples/run_xlnet_classifier.py

@@ -67,6 +67,8 @@ def main():
                         help="The initial learning rate for Adam.")
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0 limit the number of training steps to perform, you should choose only one of num_train_epochs and max_steps.")
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
                         help="Proportion of training to perform linear learning rate warmup for. "
                              "E.g., 0.1 = 10%% of training.")
@@ -189,8 +191,7 @@ def main():
         model = torch.nn.DataParallel(model)

     global_step = 0
-    nb_tr_steps = 0
-    tr_loss = 0
+    curr_tr_loss, curr_steps = 0., 1
     if args.do_train:
         if args.local_rank in [-1, 0]:
@@ -229,11 +230,14 @@ def main():
         train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
         if args.local_rank == -1:
-            train_sampler = RandomSampler(train_data)
+            train_sampler = SequentialSampler(train_data)  # RandomSampler(train_data)
         else:
             train_sampler = DistributedSampler(train_data)
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

+        if args.max_steps > 0:
+            num_train_optimization_steps = args.max_steps
+        else:
+            num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

         # Prepare optimizer
@@ -275,22 +279,16 @@ def main():
         logger.info("  Num steps = %d", num_train_optimization_steps)

         model.train()
-        for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
-            tr_loss = 0
-            nb_tr_examples, nb_tr_steps = 0, 0
-            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
+        for _ in trange(int(args.num_train_epochs) if args.max_steps <= 0 else int('Inf'),
+                        desc="Epoch", disable=args.local_rank not in [-1, 0]):
+            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, label_ids = batch

-                # define a new function to compute loss values for both output_modes
-                logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
-
-                if output_mode == "classification":
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
-                elif output_mode == "regression":
-                    loss_fct = MSELoss()
-                    loss = loss_fct(logits.view(-1), label_ids.view(-1))
+                loss, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)

                 if n_gpu > 1:
                     loss = loss.mean()  # mean() to average on multi-gpu.
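The training step now relies on the model head to compute its own loss whenever `labels` is passed, which is why the CrossEntropyLoss/MSELoss branch disappears from the script. A short fragment of the intended call, reusing the variable names from the hunk above (so it assumes `model`, `input_ids`, `segment_ids`, `input_mask` and `label_ids` are prepared exactly as in the script):

    # Fragment: names come from run_xlnet_classifier.py above.
    outputs = model(input_ids, token_type_ids=segment_ids,
                    attention_mask=input_mask, labels=label_ids)
    loss = outputs[0]   # the loss is prepended to the returned list when labels are given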
@@ -302,12 +300,10 @@ def main():
                 else:
                     loss.backward()
-                    if args.clip_gradients > 0.0:
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)
+                    gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradients)

-                tr_loss += loss.item()
-                nb_tr_examples += input_ids.size(0)
-                nb_tr_steps += 1
+                curr_tr_loss += loss.item()
+                curr_steps += 1
                 if (step + 1) % args.gradient_accumulation_steps == 0:
                     if args.fp16:
                         # modify learning rate with special warm up BERT uses
@@ -318,10 +314,19 @@ def main():
                     optimizer.step()
                     optimizer.zero_grad()
                     global_step += 1
-                    if args.local_rank in [-1, 0] and (args.log_every <= 0 or (step + 1) % args.log_every == 0):
-                        if not args.fp16:
-                            tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
-                        tb_writer.add_scalar('loss', loss.item(), global_step)
+                    if args.local_rank in [-1, 0] and (args.log_every <= 0 or (global_step + 1) % args.log_every == 0):
+                        learning_rate = optimizer.get_lr()[0] if not args.fp16 else lr_this_step
+                        logger.info("[{}] | gnorm {:.2f} lr {:8.6f} | loss {:.2f}".format(
+                            global_step, gnorm, learning_rate, curr_tr_loss / curr_steps))
+                        tb_writer.add_scalar('lr', learning_rate, global_step)
+                        tb_writer.add_scalar('loss', curr_tr_loss / curr_steps, global_step)
+                        curr_tr_loss, curr_steps = 0., 1
+                if args.max_steps > 0 and global_step > args.max_steps:
+                    break
+            if args.max_steps > 0 and global_step > args.max_steps:
+                break

     ### Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
     ### Example:
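The max_steps cap works by breaking out of both loops once global_step exceeds the limit; because `break` only leaves the innermost loop, the check is repeated at the epoch level. A standalone toy loop showing the same pattern:

    # Toy illustration of the nested-break pattern used above.
    max_steps, global_step = 5, 0
    for epoch in range(1000):                 # effectively "until max_steps"
        for step in range(3):                 # stands in for the dataloader
            global_step += 1
            if max_steps > 0 and global_step > max_steps:
                break
        if max_steps > 0 and global_step > max_steps:
            break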
@@ -435,7 +440,7 @@ def main():
             preds = np.squeeze(preds)
         result = compute_metrics(task_name, preds, out_label_ids)
-        loss = tr_loss / global_step if args.do_train else None
+        loss = curr_tr_loss / curr_steps if args.do_train else None

         result['eval_loss'] = eval_loss
         result['global_step'] = global_step
@@ -508,7 +513,7 @@ def main():
             preds = np.argmax(preds, axis=1)
         result = compute_metrics(task_name, preds, out_label_ids)
-        loss = tr_loss / global_step if args.do_train else None
+        loss = curr_tr_loss / curr_steps if args.do_train else None

         result['eval_loss'] = eval_loss
         result['global_step'] = global_step
pytorch_pretrained_bert/modeling.py

@@ -270,15 +270,13 @@ class BertEmbeddings(nn.Module):

 class BertSelfAttention(nn.Module):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
         self.output_attentions = output_attentions
-        self.keep_multihead_output = keep_multihead_output
-        self.multihead_output = None

         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -329,9 +327,9 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        if self.output_attentions:
-            return attention_probs, context_layer
-        return context_layer
+
+        outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
+        return outputs


 class BertSelfOutput(nn.Module):
@@ -349,11 +347,10 @@ class BertSelfOutput(nn.Module):

 class BertAttention(nn.Module):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False):
         super(BertAttention, self).__init__()
         self.output_attentions = output_attentions
-        self.self = BertSelfAttention(config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output)
+        self.self = BertSelfAttention(config, output_attentions=output_attentions)
         self.output = BertSelfOutput(config)

     def prune_heads(self, heads):
@@ -374,13 +371,10 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads

     def forward(self, input_tensor, attention_mask, head_mask=None):
-        self_output = self.self(input_tensor, attention_mask, head_mask)
-        if self.output_attentions:
-            attentions, self_output = self_output
-        attention_output = self.output(self_output, input_tensor)
-        if self.output_attentions:
-            return attentions, attention_output
-        return attention_output
+        self_outputs = self.self(input_tensor, attention_mask, head_mask)
+        attention_output = self.output(self_outputs[0], input_tensor)
+        outputs = [attention_output] + self_outputs[1:]  # add attentions if we output them
+        return outputs


 class BertIntermediate(nn.Module):
@@ -413,48 +407,52 @@ class BertOutput(nn.Module):

 class BertLayer(nn.Module):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
         self.output_attentions = output_attentions
-        self.attention = BertAttention(config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output)
+        self.attention = BertAttention(config, output_attentions=output_attentions)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask, head_mask=None):
-        attention_output = self.attention(hidden_states, attention_mask, head_mask)
-        if self.output_attentions:
-            attentions, attention_output = attention_output
-        intermediate_output = self.intermediate(attention_output)
+        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
+        intermediate_output = self.intermediate(attention_outputs[0])
         layer_output = self.output(intermediate_output, attention_output)
-        if self.output_attentions:
-            return attentions, layer_output
-        return layer_output
+        outputs = [layer_output] + attention_outputs[1:]  # add attentions if we output them
+        return outputs


 class BertEncoder(nn.Module):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertEncoder, self).__init__()
         self.output_attentions = output_attentions
-        layer = BertLayer(config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

-    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None):
-        all_encoder_layers = []
+    def forward(self, hidden_states, attention_mask, head_mask=None):
+        all_hidden_states = []
         all_attentions = []
         for i, layer_module in enumerate(self.layer):
-            hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])
+            if self.output_hidden_states:
+                all_hidden_states.append(hidden_states)
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
+            hidden_states = layer_outputs[0]
             if self.output_attentions:
-                attentions, hidden_states = hidden_states
-                all_attentions.append(attentions)
-            if output_all_encoded_layers:
-                all_encoder_layers.append(hidden_states)
-        if not output_all_encoded_layers:
-            all_encoder_layers.append(hidden_states)
+                all_attentions.append(layer_outputs[1])
+        # Add last layer
+        if self.output_hidden_states:
+            all_hidden_states.append(hidden_states)
+        outputs = [hidden_states]
+        if self.output_hidden_states:
+            outputs.append(all_hidden_states)
         if self.output_attentions:
-            return all_attentions, all_encoder_layers
-        return all_encoder_layers
+            outputs.append(all_attentions)
+        return outputs  # outputs, (hidden states), (attentions)


 class BertPooler(nn.Module):
@@ -617,12 +615,13 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertModel, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.embeddings = BertEmbeddings(config)
         self.encoder = BertEncoder(config, output_attentions=output_attentions,
-                                   keep_multihead_output=keep_multihead_output)
+                                   output_hidden_states=output_hidden_states)
         self.pooler = BertPooler(config)
         self.apply(self.init_weights)
@@ -633,13 +632,7 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def get_multihead_outputs(self):
-        """ Gather all multi-head outputs.
-            Return: list (layers) of multihead module outputs with gradients
-        """
-        return [layer.attention.self.multihead_output for layer in self.encoder.layer]
-
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -676,19 +669,14 @@ class BertModel(BertPreTrainedModel):
             head_mask = [None] * self.config.num_hidden_layers

         embedding_output = self.embeddings(input_ids, token_type_ids)
-        encoded_layers = self.encoder(embedding_output,
+        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
-                                      output_all_encoded_layers=output_all_encoded_layers,
                                       head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, encoded_layers = encoded_layers
-        sequence_output = encoded_layers[-1]
+        sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
-        if not output_all_encoded_layers:
-            encoded_layers = encoded_layers[-1]
-        if self.output_attentions:
-            return all_attentions, encoded_layers, pooled_output
-        return encoded_layers, pooled_output
+        outputs = [sequence_output, pooled_output] + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


 class BertForPreTraining(BertPreTrainedModel):
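BertModel.forward() now always returns sequence_output and pooled_output first, with hidden states and attentions appended only when the corresponding flags were set at construction time, so the output_all_encoded_layers switch is gone. A sketch of the intended usage; it assumes the package's usual BertConfig constructor and that the refactor is applied consistently (this is a transitional commit, so treat it as the target API rather than a guaranteed-working snippet at this exact revision):

    import torch
    from pytorch_pretrained_bert.modeling import BertConfig, BertModel

    # Small illustrative config; argument names follow the package's usual BertConfig.
    config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=48,
                        num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
    model = BertModel(config, output_attentions=True, output_hidden_states=True)
    model.eval()

    input_ids = torch.randint(0, 100, (2, 7))
    with torch.no_grad():
        outputs = model(input_ids)

    sequence_output, pooled_output = outputs[0], outputs[1]
    all_hidden_states, all_attentions = outputs[2], outputs[3]   # present because both flags are set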
@@ -746,32 +734,33 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForPreTraining, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+                              output_hidden_states=output_hidden_states)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_weights)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, pooled_output = outputs
-        else:
-            sequence_output, pooled_output = outputs
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, head_mask=None):
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+        outputs = [prediction_scores, seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here

         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
-            return total_loss
-        elif self.output_attentions:
-            return all_attentions, prediction_scores, seq_relationship_score
-        return prediction_scores, seq_relationship_score
+            outputs = [total_loss] + outputs
+        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


 class BertForMaskedLM(BertPreTrainedModel):
pytorch_pretrained_bert/modeling_xlm.py

@@ -919,9 +919,11 @@ class XLMModel(XLMPreTrainedModel):

 class XLMModel(XLMPreTrainedModel):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(XLMModel, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.mem_len = config.mem_len
         self.reuse_len = config.reuse_len
         self.d_model = config.d_model
@@ -1038,8 +1040,7 @@ class XLMModel(XLMPreTrainedModel):
         return pos_emb

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
-                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                output_all_encoded_layers=True, head_mask=None):
+                mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1188,23 +1189,45 @@ class XLMModel(XLMPreTrainedModel):
             mems = [None] * len(self.layer)

         hidden_states = []
+        attentions = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
             new_mems.append(self.cache_mem(output_h, mems[i]))
             # Save hidden_states
             if output_g is None:
                 hidden_states.append(output_h)
             else:
                 hidden_states.append((output_h, output_g))

             output_h, output_g = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                               r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
                                               head_mask=head_mask)
         # Save last hidden_state
         if output_g is None:
             hidden_states.append(output_h)
         else:
             hidden_states.append((output_h, output_g))

         # Select the right output and add dropout
         output = self.dropout(output_g if output_g is not None else output_h)

         # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
         output = output.permute(1, 0, 2).contiguous()
         if output_g is None:
             hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
         else:
             hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
-        return output, hidden_states, new_mems
+
+        # Build the list of outputs
+        outputs = [output, new_mems]
+        if self.output_attentions:
+            outputs.append(attentions)
+        if self.output_hidden_states:
+            outputs.append(hidden_states)
+        return outputs


 class XLMPredLayer(nn.Module):
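One detail worth flagging: this file appends the optional entries attentions-first and hidden-states-second, while the BERT encoder and the XLNet model below append hidden states before attentions, so position-based unpacking has to follow the flags that were actually set. A fragment keyed off those flags (it assumes `model` is an XLMModel built as above and `inp_k` is prepared as described in its docstring):

    # Fragment: unpacking XLMModel outputs according to the flags set at construction.
    outputs = model(inp_k)
    output, new_mems = outputs[0], outputs[1]
    if model.output_attentions:
        attentions = outputs[2]
    if model.output_hidden_states:
        hidden_states = outputs[3] if model.output_attentions else outputs[2]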
@@ -1309,14 +1332,15 @@ class XLMLMHeadModel(XLMPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(XLMLMHeadModel, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.attn_type = config.attn_type
         self.same_length = config.same_length
-        self.transformer = XLMModel(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
+        self.transformer = XLMModel(config, output_attentions=output_attentions,
+                                    output_hidden_states=output_hidden_states)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
         # Tie weights
@@ -1331,7 +1355,7 @@ class XLMLMHeadModel(XLMPreTrainedModel):
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1358,33 +1382,28 @@ class XLMLMHeadModel(XLMPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                                           mems, perm_mask, target_mapping, inp_q,
-                                                           output_all_encoded_layers, head_mask)
+        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        output = transformer_outputs[0]

         logits = self.lm_loss(output)
+        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
-            return loss, new_mems
+            outputs = [loss] + outputs

-        # if self.output_attentions:
-        #     all_attentions, encoded_layers = encoded_layers
-        #     sequence_output = encoded_layers[-1]
-        #     pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        #     encoded_layers = encoded_layers[-1]
-        # if self.output_attentions:
-        return logits, new_mems
-        # return all_attentions, encoded_layers, pooled_output
+        outputs = [logits] + outputs
+        return outputs


 class XLMSequenceSummary(nn.Module):
-    def __init__(self, config, summary_type="last", use_proj=True,
-                 output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, summary_type="last", use_proj=True):
         super(XLMSequenceSummary, self).__init__()
         self.summary_type = summary_type
         if use_proj:
@@ -1481,26 +1500,23 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
     ```
     """
     def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
-                 output_attentions=False, keep_multihead_output=False):
+                 output_attentions=False, output_hidden_states=False):
         super(XLMForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.attn_type = config.attn_type
         self.same_length = config.same_length
+        self.output_hidden_states = output_hidden_states
         self.summary_type = summary_type
         self.num_labels = num_labels
-        self.transformer = XLMModel(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
-        self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, use_proj=use_proj,
-                                                   output_attentions=output_attentions,
-                                                   keep_multihead_output=keep_multihead_output)
+        self.transformer = XLMModel(config, output_attentions=output_attentions,
+                                    output_hidden_states=output_hidden_states)
+        self.sequence_summary = XLMSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
         self.logits_proj = nn.Linear(config.d_model, num_labels)
         self.apply(self.init_weights)

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                labels=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1528,13 +1544,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                 Only used during pretraining for two-stream attention.
                 Set to None during finetuning.
         """
-        output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q,
-                                               output_all_encoded_layers, head_mask)
+        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        output = transformer_outputs[0]

         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
+        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

         if labels is not None:
             if self.num_labels == 1:
                 # We are doing regression
@@ -1543,17 +1561,11 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss, new_mems
+            outputs = [loss] + outputs

-        # if self.output_attentions:
-        #     all_attentions, encoded_layers = encoded_layers
-        #     sequence_output = encoded_layers[-1]
-        #     pooled_output = self.pooler(sequence_output)
-        # if not output_all_encoded_layers:
-        #     encoded_layers = encoded_layers[-1]
-        # if self.output_attentions:
-        return logits, new_mems
-        # return all_attentions, encoded_layers, pooled_output
+        outputs = [logits] + outputs
+        return outputs


 class XLMForQuestionAnswering(XLMPreTrainedModel):
@@ -1612,27 +1624,30 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(XLMForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
-        self.transformer = XLMModel(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.transformer = XLMModel(config, output_attentions=output_attentions,
+                                    output_hidden_states=output_hidden_states)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_weights)

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None,
-                output_all_encoded_layers=True, head_mask=None):
-        output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, inp_q,
-                                               output_all_encoded_layers, head_mask)
+                start_positions=None, end_positions=None, head_mask=None):
+        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+                                               mems, perm_mask, target_mapping, inp_q, head_mask)
+        output = transformer_outputs[0]

         logits = self.qa_outputs(output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
+        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
@@ -1648,7 +1663,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return total_loss
-        elif self.output_attentions:
-            return all_attentions, start_logits, end_logits
-        return start_logits, end_logits
+            outputs = [total_loss] + outputs

+        outputs = [start_logits, end_logits] + outputs
+        return outputs
pytorch_pretrained_bert/modeling_xlnet.py

@@ -323,16 +323,13 @@ except ImportError:
             return self.weight * x + self.bias


 class XLNetRelativeAttention(nn.Module):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False):
         super(XLNetRelativeAttention, self).__init__()
+        self.output_attentions = output_attentions

         if config.d_model % config.n_head != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.d_model, config.n_head))

-        self.output_attentions = output_attentions
-        self.keep_multihead_output = keep_multihead_output
-        self.multihead_output = None
-
         self.n_head = config.n_head
         self.d_head = config.d_head
@@ -368,7 +365,7 @@ class XLNetRelativeAttention(nn.Module):
         return x

-    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
+    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None):
         """Core relative positional attention operations."""

         # content based attention score
@@ -395,9 +392,16 @@ class XLNetRelativeAttention(nn.Module):
         attn_prob = F.softmax(attn_score, dim=1)
         attn_prob = self.dropout(attn_prob)

+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_prob = attn_prob * head_mask
+
         # attention output
         attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

+        if self.output_attentions:
+            return attn_vec, attn_prob
+
         return attn_vec

     def post_attention(self, h, attn_vec, residual=True):
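head_mask is applied multiplicatively to the attention probabilities, so a zero entry silences a head and a one leaves it untouched. A standalone illustration of the broadcasting involved (the shapes are illustrative; in rel_attn_core the head axis is the last axis of attn_prob):

    import torch

    # attn_prob laid out as [qlen, klen, bsz, n_head], matching the 'ijbn' einsum above.
    qlen, klen, bsz, n_head = 3, 3, 2, 4
    attn_prob = torch.softmax(torch.randn(qlen, klen, bsz, n_head), dim=1)

    head_mask = torch.ones(n_head)
    head_mask[1] = 0.0                     # silence head 1
    attn_prob = attn_prob * head_mask      # broadcasts over the trailing head dimension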
@@ -439,7 +443,10 @@ class XLNetRelativeAttention(nn.Module):

             # core attention ops
             attn_vec_h = self.rel_attn_core(
-                q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
+                q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask)
+
+            if self.output_attentions:
+                attn_vec_h, attn_prob_h = attn_vec_h

             # post processing
             output_h = self.post_attention(h, attn_vec_h)
@@ -452,14 +459,25 @@ class XLNetRelativeAttention(nn.Module):
             if target_mapping is not None:
                 q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
                 attn_vec_g = self.rel_attn_core(
-                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g)
+                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask)
+
+                if self.output_attentions:
+                    attn_vec_g, attn_prob_g = attn_vec_g
+
                 attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
             else:
                 attn_vec_g = self.rel_attn_core(
-                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g)
+                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask)
+
+                if self.output_attentions:
+                    attn_vec_g, attn_prob_g = attn_vec_g

             # post processing
             output_g = self.post_attention(g, attn_vec_g)

+            if self.output_attentions:
+                attn_prob = attn_prob_h, attn_prob_g
+
         else:
             ###### Multi-head attention with relative positional encoding
             if mems is not None and mems.dim() > 1:
@@ -477,30 +495,18 @@ class XLNetRelativeAttention(nn.Module):
             # core attention ops
             attn_vec = self.rel_attn_core(
-                q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h)
+                q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask)
+
+            if self.output_attentions:
+                attn_vec, attn_prob = attn_vec

             # post processing
             output_h = self.post_attention(h, attn_vec)
             output_g = None

+        if self.output_attentions:
+            return output_h, output_g, attn_prob
-        # Mask heads if we want to
-        # if head_mask is not None:
-        #     attention_probs = attention_probs * head_mask
-        # context_layer = torch.matmul(attention_probs, value_layer)
-        # if self.keep_multihead_output:
-        #     self.multihead_output = context_layer
-        #     self.multihead_output.retain_grad()
-        # context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        # context_layer = context_layer.view(*new_context_layer_shape)
-        # if self.output_attentions:
-        #     attentions, self_output = self_output
-        # if self.output_attentions:
-        #     return attentions, attention_output
         return output_h, output_g


 class XLNetFeedForward(nn.Module):
@@ -510,7 +516,8 @@ class XLNetFeedForward(nn.Module):
         self.layer_1 = nn.Linear(config.d_model, config.d_inner)
         self.layer_2 = nn.Linear(config.d_inner, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
-        if isinstance(config.ff_activation, str) or (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
+        if isinstance(config.ff_activation, str) or \
+                (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
             self.activation_function = ACT2FN[config.ff_activation]
         else:
             self.activation_function = config.ff_activation
@@ -526,29 +533,27 @@ class XLNetFeedForward(nn.Module):
         return output


 class XLNetLayer(nn.Module):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False):
         super(XLNetLayer, self).__init__()
         self.output_attentions = output_attentions
-        self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions,
-                                               keep_multihead_output=keep_multihead_output)
+        self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions)
         self.ff = XLNetFeedForward(config)
         self.dropout = nn.Dropout(config.dropout)

     def forward(self, output_h, output_g, attn_mask_h, attn_mask_g,
                 r, seg_mat, mems=None, target_mapping=None, head_mask=None):
-        output_h, output_g = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat,
-                                           mems=mems, target_mapping=target_mapping, head_mask=head_mask)
+        outputs = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g,
+                                r, seg_mat, mems=mems, target_mapping=target_mapping, head_mask=head_mask)
+        output_h, output_g = outputs[:2]

         if output_g is not None:
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)

-        # if self.output_attentions:
-        #     return attentions, layer_output
-        return output_h, output_g
+        outputs = [output_h, output_g] + outputs[2:]  # Add again attentions if there are there
+        return outputs


 class XLNetPreTrainedModel(PreTrainedModel):
@@ -584,9 +589,11 @@ class XLNetPreTrainedModel(PreTrainedModel):

 class XLNetModel(XLNetPreTrainedModel):
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(XLNetModel, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.mem_len = config.mem_len
         self.reuse_len = config.reuse_len
         self.d_model = config.d_model
@@ -597,8 +604,7 @@ class XLNetModel(XLNetPreTrainedModel):
         self.word_embedding = nn.Embedding(config.n_token, config.d_model)
         self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
-        layer = XLNetLayer(config, output_attentions=output_attentions,
-                           keep_multihead_output=keep_multihead_output)
+        layer = XLNetLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)
@@ -851,28 +857,39 @@ class XLNetModel(XLNetPreTrainedModel):
         if mems is None:
             mems = [None] * len(self.layer)

+        attentions = []
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
             new_mems.append(self.cache_mem(output_h, mems[i]))
+            if self.output_hidden_states:
+                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

-            output_h, output_g = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
-                                              r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
-                                              head_mask=head_mask)
+            outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
+                                   r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
+                                   head_mask=head_mask)
+            output_h, output_g = outputs[:2]
+            if self.output_attentions:
+                attentions.append(outputs[2:])

+        # Add last hidden state
+        if self.output_hidden_states:
+            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

         output = self.dropout(output_g if output_g is not None else output_h)

-        # We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-        output = output.permute(1, 0, 2).contiguous()
-        return output, hidden_states, new_mems
+        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
+        outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
+        if self.output_hidden_states:
+            if output_g is not None:
+                hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+            else:
+                hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
+            outputs.append(hidden_states)
+        if self.output_attentions:
+            outputs.append(attentions)
+
+        return outputs  # outputs, new_mems, (hidden_states), (attentions)


 class XLNetLMHeadModel(XLNetPreTrainedModel):
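XLNetModel.forward() now follows the same list convention, with new_mems promoted into the returned list right after the main output, as the trailing comment spells out. A fragment showing how a caller is expected to consume it (assuming `model` is an XLNetModel configured as above and `input_ids` is a [bsz, len] LongTensor):

    # Fragment: outputs = [output, new_mems, (hidden_states), (attentions)]
    outputs = model(input_ids)
    output, new_mems = outputs[0], outputs[1]
    # hidden states come next if output_hidden_states was set; attentions, if requested, come last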
@@ -936,14 +953,16 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(XLNetLMHeadModel, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.attn_type = config.attn_type
         self.same_length = config.same_length

         self.transformer = XLNetModel(config, output_attentions=output_attentions,
-                                      keep_multihead_output=keep_multihead_output)
+                                      output_hidden_states=output_hidden_states)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
         # Tie weights
@@ -989,27 +1008,24 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        output, hidden_states, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
                                                mems, perm_mask, target_mapping, inp_q, head_mask)

-        logits = self.lm_loss(output)
+        logits = self.lm_loss(transformer_outputs[0])
+        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
-            return loss, new_mems
+            outputs = [loss] + outputs

-        # if self.output_attentions:
-        #     all_attentions, encoded_layers = encoded_layers
-        # if self.output_attentions:
-        return logits, new_mems
-        # return all_attentions, encoded_layers, pooled_output
+        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)


 class XLNetSequenceSummary(nn.Module):
-    def __init__(self, config, summary_type="last", use_proj=True,
-                 output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, summary_type="last", use_proj=True):
         super(XLNetSequenceSummary, self).__init__()
         self.summary_type = summary_type
         if use_proj:
@@ -1106,20 +1122,20 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     ```
     """
     def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
-                 output_attentions=False, keep_multihead_output=False):
+                 output_attentions=False, output_hidden_states=False):
         super(XLNetForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.attn_type = config.attn_type
         self.same_length = config.same_length
         self.summary_type = summary_type
         self.num_labels = num_labels

         self.transformer = XLNetModel(config, output_attentions=output_attentions,
-                                      keep_multihead_output=keep_multihead_output)
-        self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj,
-                                                     output_attentions=output_attentions,
-                                                     keep_multihead_output=keep_multihead_output)
+                                      output_hidden_states=output_hidden_states)
+        self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
         self.logits_proj = nn.Linear(config.d_model, num_labels)
         self.apply(self.init_weights)
@@ -1153,12 +1169,15 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
                 Only used during pretraining for two-stream attention.
                 Set to None during finetuning.
         """
-        output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
                                                mems, perm_mask, target_mapping, inp_q, head_mask)
+        output = transformer_outputs[0]

         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
+        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             if self.num_labels == 1:
                 # We are doing regression
@@ -1167,13 +1186,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss, new_mems
+            outputs = [loss] + outputs

-        # if self.output_attentions:
-        #     all_attentions, encoded_layers = encoded_layers
-        # if self.output_attentions:
-        return logits, new_mems
-        # return all_attentions, encoded_layers, pooled_output
+        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)


 class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     """XLNet model for Question Answering (span extraction).
@@ -1231,25 +1247,30 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(XLNetForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.transformer = XLNetModel(config, output_attentions=output_attentions,
-                                      keep_multihead_output=keep_multihead_output)
+                                      output_hidden_states=output_hidden_states)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_weights)

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                 start_positions=None, end_positions=None, head_mask=None):
-        output, _, new_mems = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
+        transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
                                                mems, perm_mask, target_mapping, inp_q, head_mask)

-        logits = self.qa_outputs(output)
+        logits = self.qa_outputs(transformer_outputs[0])
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
+        outputs = [start_logits, end_logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
@@ -1265,7 +1286,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return total_loss
-        elif self.output_attentions:
-            return all_attentions, start_logits, end_logits
-        return start_logits, end_logits
+            outputs = [total_loss] + outputs

+        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
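Since the commit message singles out XLNetForSequenceClassification as the working path, here is a hedged sketch of a training-style call against it, reusing the tensor names from run_xlnet_classifier.py (so `model` and the input tensors are assumed to be set up as in that script):

    # Fragment: names follow run_xlnet_classifier.py.
    outputs = model(input_ids, token_type_ids=segment_ids,
                    attention_mask=input_mask, labels=label_ids)
    loss, logits = outputs[0], outputs[1]    # followed by new_mems and any requested hidden states/attentions
    loss.backward()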