Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
15b70338
Commit
15b70338
authored
Jul 04, 2019
by
thomwolf
Browse files
adding squad model to xlnet and xlm
parent
fbe04423
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
531 additions
and
326 deletions
+531
-326
pytorch_pretrained_bert/model_utils.py
pytorch_pretrained_bert/model_utils.py
+182
-15
pytorch_pretrained_bert/modeling_xlm.py
pytorch_pretrained_bert/modeling_xlm.py
+86
-149
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+68
-26
pytorch_pretrained_bert/tests/modeling_openai_test.py
pytorch_pretrained_bert/tests/modeling_openai_test.py
+0
-4
pytorch_pretrained_bert/tests/modeling_xlm_test.py
pytorch_pretrained_bert/tests/modeling_xlm_test.py
+94
-90
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
+101
-42
No files found.
pytorch_pretrained_bert/model_utils.py
View file @
15b70338
...
...
@@ -25,7 +25,7 @@ from io import open
import
torch
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
,
functional
as
F
from
.file_utils
import
cached_path
...
...
@@ -301,8 +301,174 @@ class Conv1D(nn.Module):
return
x
class
SequenceSummary
(
nn
.
Module
):
class PoolerStartLogits(nn.Module):
    """ Compute SQuAD start_logits from sequence hidden states. """
    def __init__(self, config):
        """ Build the scoring layer.

            Args:
                `config`: object exposing `hidden_size`, the width of the hidden states to score.
        """
        super(PoolerStartLogits, self).__init__()
        # One scalar score per position.
        self.dense = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        """ Args:
            `hidden_states`: hidden states whose last dimension is `hidden_size`.
            `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
                shape [batch_size, seq_len]. 1.0 means token should be masked.

            Returns:
                The per-position logits (last dimension of the projection squeezed away), with masked
                positions pushed to a very large negative value.
        """
        x = self.dense(hidden_states).squeeze(-1)

        if p_mask is not None:
            # 1e30 is not representable in float16: it becomes inf, and inf * 0 turns every
            # *unmasked* logit into NaN. Use a constant that fits in half precision instead.
            mask_value = 65500 if x.dtype == torch.float16 else 1e30
            x = x * (1 - p_mask) - mask_value * p_mask

        return x
class PoolerEndLogits(nn.Module):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
    """
    def __init__(self, config):
        """ Build the end-position scorer.

            Args:
                `config`: object exposing `hidden_size` and `layer_norm_eps`.
        """
        super(PoolerEndLogits, self).__init__()
        # Input is the concatenation [hidden_state ; start_state], hence 2 * hidden_size.
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dense_1 = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
        """ Args:
            One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
            `hidden_states`: hidden states whose last dimension is `hidden_size`.
            `start_states`: hidden states of the first tokens for the labeled span: torch.FloatTensor
                of shape identical to hidden_states (it is concatenated with them on the last dim).
            `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
            `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
                shape [batch_size, seq_len]. 1.0 means token should be masked.

            Returns:
                The end logits, one score per position, with masked positions pushed to a very
                large negative value.
        """
        slen, hsz = hidden_states.shape[-2:]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            # Pick, for every sample, the hidden state located at its start position and
            # repeat it along the sequence axis so it can be paired with every token.
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
        x = self.activation(x)
        x = self.LayerNorm(x)
        x = self.dense_1(x).squeeze(-1)

        if p_mask is not None:
            # 1e30 overflows to inf in float16 and inf * 0 yields NaN on unmasked positions;
            # fall back to a half-representable constant in that case.
            mask_value = 65500 if x.dtype == torch.float16 else 1e30
            x = x * (1 - p_mask) - mask_value * p_mask

        return x
class PoolerAnswerClass(nn.Module):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
    def __init__(self, config):
        """ Build the answerability classifier.

            Args:
                `config`: object exposing `hidden_size`.
        """
        super(PoolerAnswerClass, self).__init__()
        # Input is the concatenation [start_state ; cls_state], hence 2 * hidden_size.
        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.activation = nn.Tanh()
        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
        """ Args:
            One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
            `hidden_states`: hidden states of shape (bsz, slen, hsz).
            `start_states`: hidden state of the first token of the span: torch.FloatTensor of shape
                (bsz, hsz) (it is concatenated with the (bsz, hsz) CLS state on the last dim).
            `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
            `cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.

            Returns:
                One answerability logit per sample, shape (bsz,).

            note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
        """
        hsz = hidden_states.shape[-1]
        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
        if start_positions is not None:
            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)

        if cls_index is not None:
            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
        else:
            # No explicit CLS position: use the last token of the sequence.
            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)

        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
        x = self.activation(x)
        x = self.dense_1(x).squeeze(-1)

        return x
class SQuADHead(nn.Module):
    """ A SQuAD head inspired by XLNet.

        Combines a start-logits pooler, an end-logits pooler conditioned on the start token and an
        answerability classifier. When `start_positions` and `end_positions` are given it returns a
        loss; otherwise it runs a beam search over the `start_n_top` best starts and, for each of
        them, the `end_n_top` best ends.
    """
    def __init__(self, config):
        """ Args:
            `config`: object exposing `start_n_top` and `end_n_top`, plus whatever the three pooler
                sub-modules read from it.
        """
        super(SQuADHead, self).__init__()
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

    def forward(self, hidden_states, start_positions=None, end_positions=None,
                cls_index=None, is_impossible=None, p_mask=None):
        """ Args:
            `hidden_states`: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            `start_positions`, `end_positions`: [optional] labels of the span boundaries; both must be
                given to get a loss.
            `cls_index`: [optional] position of the CLS token for each sample.
            `is_impossible`: [optional] answerability labels fed to BCEWithLogitsLoss.
            `p_mask`: [optional] mask of positions that cannot be part of the answer (1.0 = masked).

            Returns a tuple:
                with labels: total_loss, start_logits, end_logits, (cls_logits)
                without labels: start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
                Despite their names the `*_log_probs` values are softmax probabilities.
        """
        outputs = ()

        start_logits = self.start_logits(hidden_states, p_mask)

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5
                outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
            else:
                outputs = (total_loss, start_logits, end_logits) + outputs

        else:
            # during inference, compute the end logits based on beam search
            _, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
            # Keep the (bsz, start_n_top) indices for the caller; only the gather needs the
            # hidden-size-expanded copy. Overwriting `start_top_index` here would return a
            # (bsz, start_n_top, hsz) tensor instead of the token positions.
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
            # Add a trailing axis so the mask broadcasts over the start_n_top candidates.
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            # Probability-weighted average of the hidden states stands in for the start state.
            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
        # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
        return outputs
class
SequenceSummary
(
nn
.
Module
):
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
...
...
@@ -317,6 +483,7 @@ class SequenceSummary(nn.Module):
'tanh' => add a tanh activation to the output
None => no activation
"""
def
__init__
(
self
,
config
):
super
(
SequenceSummary
,
self
).
__init__
()
self
.
summary_type
=
config
.
summary_type
if
hasattr
(
config
,
'summary_use_proj'
)
else
'last'
...
...
pytorch_pretrained_bert/modeling_xlm.py
View file @
15b70338
...
...
@@ -35,7 +35,8 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
from
.model_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
,
SequenceSummary
,
SQuADHead
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -67,15 +68,23 @@ class XLMConfig(PretrainedConfig):
n_langs
=
1
,
max_position_embeddings
=
512
,
embed_init_std
=
2048
**
-
0.5
,
layer_norm_eps
=
1e-12
,
init_std
=
0.02
,
summary_type
=
"last"
,
use_proj
=
True
,
bos_index
=
0
,
eos_index
=
1
,
pad_index
=
2
,
unk_index
=
3
,
mask_index
=
5
,
is_encoder
=
True
,
finetuning_task
=
None
,
num_labels
=
2
,
summary_type
=
'last'
,
summary_use_proj
=
True
,
summary_activation
=
'tanh'
,
summary_dropout
=
0.1
,
start_n_top
=
5
,
end_n_top
=
5
,
**
kwargs
):
"""Constructs XLMConfig.
...
...
@@ -140,8 +149,7 @@ class XLMConfig(PretrainedConfig):
self
.
causal
=
causal
self
.
asm
=
asm
self
.
n_langs
=
n_langs
self
.
summary_type
=
summary_type
self
.
use_proj
=
use_proj
self
.
layer_norm_eps
=
layer_norm_eps
self
.
bos_index
=
bos_index
self
.
eos_index
=
eos_index
self
.
pad_index
=
pad_index
...
...
@@ -151,6 +159,14 @@ class XLMConfig(PretrainedConfig):
self
.
max_position_embeddings
=
max_position_embeddings
self
.
embed_init_std
=
embed_init_std
self
.
init_std
=
init_std
self
.
finetuning_task
=
finetuning_task
self
.
num_labels
=
num_labels
self
.
summary_type
=
summary_type
self
.
summary_use_proj
=
summary_use_proj
self
.
summary_activation
=
summary_activation
self
.
summary_dropout
=
summary_dropout
self
.
start_n_top
=
start_n_top
self
.
end_n_top
=
end_n_top
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
...
...
@@ -172,26 +188,6 @@ class XLMConfig(PretrainedConfig):
return
self
.
n_layers
def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None):
    """ Create an `nn.Embedding`, optionally re-initialised from the config.

        When `config` is given and its `embed_init_std` is not None, the weights are drawn from a
        normal distribution with that standard deviation. The row at `padding_idx`, if any, is
        then reset to zero.
    """
    layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)

    use_custom_std = config is not None and config.embed_init_std is not None
    if use_custom_std:
        nn.init.normal_(layer.weight, mean=0, std=config.embed_init_std)

    if padding_idx is not None:
        # The normal init above overwrites the padding row, so zero it again.
        nn.init.constant_(layer.weight[padding_idx], 0)

    return layer
def Linear(in_features, out_features, bias=True, config=None):
    """ Create an `nn.Linear`, optionally re-initialised from the config.

        When `config` is given and its `init_std` is not None, the weight is drawn from a normal
        distribution with that standard deviation and the bias (if present) is set to zero.
    """
    layer = nn.Linear(in_features, out_features, bias)

    if config is None or config.init_std is None:
        return layer

    # NOTE(review): the bias reset is assumed to be nested under the config check - confirm
    # against the upstream file, the scraped source has lost its indentation.
    nn.init.normal_(layer.weight, mean=0, std=config.init_std)
    if bias:
        nn.init.constant_(layer.bias, 0.)
    # nn.init.xavier_uniform_(m.weight)
    # nn.init.constant_(m.bias, 0.)
    return layer
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
position_enc
=
np
.
array
([
[
pos
/
np
.
power
(
10000
,
2
*
(
j
//
2
)
/
dim
)
for
j
in
range
(
dim
)]
...
...
@@ -244,7 +240,7 @@ class MultiHeadAttention(nn.Module):
NEW_ID
=
itertools
.
count
()
def
__init__
(
self
,
n_heads
,
dim
,
config
):
super
().
__init__
()
super
(
MultiHeadAttention
,
self
).
__init__
()
self
.
layer_id
=
next
(
MultiHeadAttention
.
NEW_ID
)
self
.
output_attentions
=
config
.
output_attentions
self
.
dim
=
dim
...
...
@@ -252,10 +248,10 @@ class MultiHeadAttention(nn.Module):
self
.
dropout
=
config
.
attention_dropout
assert
self
.
dim
%
self
.
n_heads
==
0
self
.
q_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
k_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
v_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
out_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
q_lin
=
nn
.
Linear
(
dim
,
dim
)
self
.
k_lin
=
nn
.
Linear
(
dim
,
dim
)
self
.
v_lin
=
nn
.
Linear
(
dim
,
dim
)
self
.
out_lin
=
nn
.
Linear
(
dim
,
dim
)
def
prune_heads
(
self
,
heads
):
attention_head_size
=
self
.
dim
//
self
.
n_heads
...
...
@@ -342,10 +338,10 @@ class MultiHeadAttention(nn.Module):
class
TransformerFFN
(
nn
.
Module
):
def
__init__
(
self
,
in_dim
,
dim_hidden
,
out_dim
,
config
):
super
().
__init__
()
super
(
TransformerFFN
,
self
).
__init__
()
self
.
dropout
=
config
.
dropout
self
.
lin1
=
Linear
(
in_dim
,
dim_hidden
,
config
=
config
)
self
.
lin2
=
Linear
(
dim_hidden
,
out_dim
,
config
=
config
)
self
.
lin1
=
nn
.
Linear
(
in_dim
,
dim_hidden
)
self
.
lin2
=
nn
.
Linear
(
dim_hidden
,
out_dim
)
self
.
act
=
gelu
if
config
.
gelu_activation
else
F
.
relu
def
forward
(
self
,
input
):
...
...
@@ -363,17 +359,21 @@ class XLMPreTrainedModel(PreTrainedModel):
config_class
=
XLMConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights
=
None
base_model_prefix
=
"
xlm
"
base_model_prefix
=
"
transformer
"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
""" Initialize the weights.
"""
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Embedding
)):
# Weights are initialized in module instantiation (see above)
pass
""" Initialize the weights. """
if
isinstance
(
module
,
nn
.
Embedding
):
if
self
.
config
is
not
None
and
self
.
config
.
embed_init_std
is
not
None
:
nn
.
init
.
normal_
(
module
.
weight
,
mean
=
0
,
std
=
self
.
config
.
embed_init_std
)
if
isinstance
(
module
,
nn
.
Linear
):
if
self
.
config
is
not
None
and
self
.
config
.
init_std
is
not
None
:
nn
.
init
.
normal_
(
module
.
weight
,
mean
=
0
,
std
=
self
.
config
.
init_std
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0.
)
if
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
...
...
@@ -471,13 +471,13 @@ class XLMModel(XLMPreTrainedModel):
assert
self
.
dim
%
self
.
n_heads
==
0
,
'transformer dim must be a multiple of n_heads'
# embeddings
self
.
position_embeddings
=
Embedding
(
config
.
max_position_embeddings
,
self
.
dim
,
config
=
config
)
self
.
position_embeddings
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
dim
)
if
config
.
sinusoidal_embeddings
:
create_sinusoidal_embeddings
(
config
.
max_position_embeddings
,
self
.
dim
,
out
=
self
.
position_embeddings
.
weight
)
if
config
.
n_langs
>
1
:
self
.
lang_embeddings
=
Embedding
(
self
.
n_langs
,
self
.
dim
,
config
=
config
)
self
.
embeddings
=
Embedding
(
self
.
n_words
,
self
.
dim
,
padding_idx
=
self
.
pad_index
,
config
=
config
)
self
.
layer_norm_emb
=
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
)
self
.
lang_embeddings
=
nn
.
Embedding
(
self
.
n_langs
,
self
.
dim
)
self
.
embeddings
=
nn
.
Embedding
(
self
.
n_words
,
self
.
dim
,
padding_idx
=
self
.
pad_index
)
self
.
layer_norm_emb
=
nn
.
LayerNorm
(
self
.
dim
,
eps
=
config
.
layer_norm_eps
)
# transformer layers
self
.
attentions
=
nn
.
ModuleList
()
...
...
@@ -490,12 +490,14 @@ class XLMModel(XLMPreTrainedModel):
for
_
in
range
(
self
.
n_layers
):
self
.
attentions
.
append
(
MultiHeadAttention
(
self
.
n_heads
,
self
.
dim
,
config
=
config
))
self
.
layer_norm1
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
self
.
layer_norm1
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
config
.
layer_norm_eps
))
# if self.is_decoder:
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=
1e-12
))
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=
config.layer_norm_eps
))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
config
=
config
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
config
.
layer_norm_eps
))
self
.
apply
(
self
.
init_weights
)
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
...
...
@@ -636,14 +638,14 @@ class XLMPredLayer(nn.Module):
Prediction layer (cross_entropy or adaptive_softmax).
"""
def
__init__
(
self
,
config
):
super
().
__init__
()
super
(
XLMPredLayer
,
self
).
__init__
()
self
.
asm
=
config
.
asm
self
.
n_words
=
config
.
n_words
self
.
pad_index
=
config
.
pad_index
dim
=
config
.
emb_dim
if
config
.
asm
is
False
:
self
.
proj
=
Linear
(
dim
,
config
.
n_words
,
bias
=
True
)
self
.
proj
=
nn
.
Linear
(
dim
,
config
.
n_words
,
bias
=
True
)
else
:
self
.
proj
=
nn
.
AdaptiveLogSoftmaxWithLoss
(
in_features
=
dim
,
...
...
@@ -653,28 +655,24 @@ class XLMPredLayer(nn.Module):
head_bias
=
True
,
# default is False
)
def
forward
(
self
,
x
,
y
,
get_scores
=
False
):
def
forward
(
self
,
x
,
y
=
None
):
""" Compute the loss, and optionally the scores.
"""
Compute the loss, and optionally the scores.
"""
assert
(
y
==
self
.
pad_index
).
sum
().
item
()
==
0
outputs
=
()
if
self
.
asm
is
False
:
scores
=
self
.
proj
(
x
).
view
(
-
1
,
self
.
n_words
)
outputs
=
(
scores
,)
+
outputs
if
y
is
not
None
:
loss
=
F
.
cross_entropy
(
scores
,
y
,
reduction
=
'elementwise_mean'
)
outputs
=
(
loss
,)
+
outputs
else
:
scores
=
self
.
proj
.
log_prob
(
x
)
outputs
=
(
scores
,)
+
outputs
if
y
is
not
None
:
_
,
loss
=
self
.
proj
(
x
,
y
)
scores
=
self
.
proj
.
log_prob
(
x
)
if
get_scores
else
None
return
scores
,
loss
def
get_scores
(
self
,
x
):
"""
Compute scores.
"""
assert
x
.
dim
()
==
2
return
self
.
proj
.
log_prob
(
x
)
if
self
.
asm
else
self
.
proj
(
x
)
outputs
=
(
loss
,)
+
outputs
return
outputs
class
XLMWithLMHeadModel
(
XLMPreTrainedModel
):
...
...
@@ -731,6 +729,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
"""
def
__init__
(
self
,
config
):
super
(
XLMWithLMHeadModel
,
self
).
__init__
(
config
)
self
.
torchscript
=
config
.
torchscript
self
.
transformer
=
XLMModel
(
config
)
self
.
pred_layer
=
XLMPredLayer
(
config
)
...
...
@@ -741,6 +740,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def tie_weights(self):
    """ Make the prediction layer use the input embedding matrix.

        With `self.torchscript` set, the projection gets its own `nn.Parameter` holding a clone of
        the embedding weights; otherwise the very same parameter object is shared.
    """
    embedding_weight = self.transformer.embeddings.weight
    if self.torchscript:
        embedding_weight = nn.Parameter(embedding_weight.clone())
    self.pred_layer.proj.weight = embedding_weight
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
...
...
@@ -775,55 +777,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
logits
=
self
.
pred_layer
(
output
,
labels
)
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
if
labels
is
not
None
:
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
size
(
-
1
)),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
[
logits
]
+
outputs
outputs
=
self
.
pred_layer
(
output
,
labels
)
outputs
=
outputs
+
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
return
outputs
class XLMSequenceSummary(nn.Module):
    """ Reduce a sequence of hidden states to a single vector per sample.

        Config attributes read here:
            `summary_type`: 'last', 'first' or 'mean' ('attn' is not implemented).
            `use_proj`: add a `d_model -> d_model` linear projection on the summary.
            `d_model`: width of the hidden states (only read when `use_proj` is set).
            `dropout`: dropout probability applied to the output.
    """
    def __init__(self, config):
        super(XLMSequenceSummary, self).__init__()
        self.summary_type = config.summary_type
        if config.use_proj:
            self.summary = nn.Linear(config.d_model, config.d_model)
        else:
            self.summary = None
        if config.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """ hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer.

            Returns the summary after the optional projection, tanh and dropout.
            Raises NotImplementedError for 'attn' and ValueError for an unknown `summary_type`.
        """
        if self.summary_type == 'last':
            output = hidden_states[:, -1]
        elif self.summary_type == 'first':
            output = hidden_states[:, 0]
        elif self.summary_type == 'mean':
            output = hidden_states.mean(dim=1)
        elif self.summary_type == 'attn':
            raise NotImplementedError
        else:
            raise ValueError("Unknown summary_type: {}".format(self.summary_type))

        # The projection only exists when the config asked for it.
        if self.summary is not None:
            output = self.summary(output)
        output = self.activation(output)
        output = self.dropout(output)
        return output
class
XLMForSequenceClassification
(
XLMPreTrainedModel
):
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
...
...
@@ -890,15 +849,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
"""
def
__init__
(
self
,
config
):
super
(
XLMForSequenceClassification
,
self
).
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
XLMModel
(
config
)
self
.
sequence_summary
=
XLMSequenceSummary
(
config
)
self
.
logits_proj
=
nn
.
Linear
(
config
.
d_model
,
config
.
num_labels
)
self
.
sequence_summary
=
SequenceSummary
(
config
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
...
...
@@ -930,10 +889,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
self
.
sequence_summary
(
output
)
logits
=
self
.
logits_proj
(
output
)
logits
=
self
.
sequence_summary
(
output
)
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
outputs
=
(
logits
,)
+
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
...
...
@@ -943,9 +901,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
else
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
[
logits
]
+
outputs
outputs
=
(
loss
,)
+
outputs
return
outputs
...
...
@@ -1010,41 +966,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
super
(
XLMForQuestionAnswering
,
self
).
__init__
(
config
)
self
.
transformer
=
XLMModel
(
config
)
self
.
qa_outputs
=
nn
.
Lin
ea
r
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
qa_outputs
=
SQuADH
ea
d
(
config
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
positions
=
positions
,
token_type_ids
=
token_type_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
logits
=
self
.
qa_outputs
(
output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, split add a dimension
if
len
(
start_positions
.
size
())
>
1
:
start_positions
=
start_positions
.
squeeze
(
-
1
)
if
len
(
end_positions
.
size
())
>
1
:
end_positions
=
end_positions
.
squeeze
(
-
1
)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index
=
start_logits
.
size
(
1
)
start_positions
.
clamp_
(
0
,
ignored_index
)
end_positions
.
clamp_
(
0
,
ignored_index
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=
ignored_index
)
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
outputs
=
[
total_loss
]
+
outputs
outputs
=
[
start_logits
,
end_logits
]
+
outputs
outputs
=
self
.
qa_outputs
(
output
,
start_positions
=
start_positions
,
end_positions
=
end_positions
,
cls_index
=
cls_index
,
is_impossible
=
is_impossible
,
p_mask
=
p_mask
)
outputs
=
outputs
+
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
return
outputs
pytorch_pretrained_bert/modeling_xlnet.py
View file @
15b70338
...
...
@@ -32,8 +32,8 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.model_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
SequenceSummary
)
from
.model_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
SequenceSummary
,
PoolerAnswerClass
,
PoolerEndLogits
,
PoolerStartLogits
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -228,6 +228,8 @@ class XLNetConfig(PretrainedConfig):
summary_use_proj
=
True
,
summary_activation
=
'tanh'
,
summary_dropout
=
0.1
,
start_n_top
=
5
,
end_n_top
=
5
,
**
kwargs
):
"""Constructs XLNetConfig.
...
...
@@ -313,6 +315,8 @@ class XLNetConfig(PretrainedConfig):
self
.
summary_use_proj
=
summary_use_proj
self
.
summary_activation
=
summary_activation
self
.
summary_dropout
=
summary_dropout
self
.
start_n_top
=
start_n_top
self
.
end_n_top
=
end_n_top
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
...
...
@@ -1114,6 +1118,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
"""
def
__init__
(
self
,
config
):
super
(
XLNetForSequenceClassification
,
self
).
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
XLNetModel
(
config
)
self
.
sequence_summary
=
SequenceSummary
(
config
)
...
...
@@ -1174,7 +1179,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
class
XLNetForQuestionAnswering
(
XLNetPreTrainedModel
):
"""XLNet model for Question Answering (span extraction).
"""
XLNet model for Question Answering (span extraction).
This module is composed of the XLNet model with a linear layer on top of
the sequence output that computes start_logits and end_logits
...
...
@@ -1231,41 +1236,78 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
"""
def
__init__
(
self
,
config
):
super
(
XLNetForQuestionAnswering
,
self
).
__init__
(
config
)
self
.
start_n_top
=
config
.
start_n_top
self
.
end_n_top
=
config
.
end_n_top
self
.
transformer
=
XLNetModel
(
config
)
self
.
qa_outputs
=
nn
.
Linear
(
config
.
d_model
,
config
.
num_labels
)
self
.
start_logits
=
PoolerStartLogits
(
config
)
self
.
end_logits
=
PoolerEndLogits
(
config
)
self
.
answer_class
=
PoolerAnswerClass
(
config
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
start_positions
=
None
,
end_positions
=
None
,
head_mask
=
None
):
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
,
input_mask
,
attention_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
)
logits
=
self
.
qa_
outputs
(
transformer_outputs
[
0
])
outputs
=
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
for
x
in
(
start_positions
,
end_positions
,
cls_index
,
is_impossible
):
if
x
is
not
None
and
x
.
dim
()
>
1
:
x
.
squeeze_
(
-
1
)
outputs
=
(
start_logits
,
end_logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
# during training, compute the end logits based on the ground truth of the start position
end_logits
=
self
.
end_logits
(
hidden_states
,
start_positions
=
start_positions
,
p_mask
=
p_mask
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, split add a dimension
if
len
(
start_positions
.
size
())
>
1
:
start_positions
=
start_positions
.
squeeze
(
-
1
)
if
len
(
end_positions
.
size
())
>
1
:
end_positions
=
end_positions
.
squeeze
(
-
1
)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index
=
start_logits
.
size
(
1
)
start_positions
.
clamp_
(
0
,
ignored_index
)
end_positions
.
clamp_
(
0
,
ignored_index
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=
ignored_index
)
loss_fct
=
CrossEntropyLoss
()
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
outputs
=
(
total_loss
,)
+
outputs
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
if
cls_index
is
not
None
and
is_impossible
is
not
None
:
# Predict answerability from the representation of CLS and START
cls_logits
=
self
.
answer_class
(
hidden_states
,
start_positions
=
start_positions
,
cls_index
=
cls_index
)
loss_fct_cls
=
nn
.
BCEWithLogitsLoss
()
cls_loss
=
loss_fct_cls
(
cls_logits
,
is_impossible
)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
# comparable to start_loss and end_loss
total_loss
+=
cls_loss
*
0.5
outputs
=
(
total_loss
,
start_logits
,
end_logits
,
cls_logits
)
+
outputs
else
:
outputs
=
(
total_loss
,
start_logits
,
end_logits
)
+
outputs
else
:
# during inference, compute the end logits based on beam search
bsz
,
slen
,
hsz
=
hidden_states
.
size
()
start_log_probs
=
F
.
softmax
(
start_logits
,
dim
=-
1
)
# shape (bsz, slen)
start_top_log_probs
,
start_top_index
=
torch
.
topk
(
start_log_probs
,
self
.
start_n_top
,
dim
=-
1
)
# shape (bsz, start_n_top)
start_top_index
=
start_top_index
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
hsz
)
# shape (bsz, start_n_top, hsz)
start_states
=
torch
.
gather
(
hidden_states
,
-
2
,
start_top_index
)
# shape (bsz, start_n_top, hsz)
start_states
=
start_states
.
unsqueeze
(
1
).
expand
(
-
1
,
slen
,
-
1
,
-
1
)
# shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded
=
hidden_states
.
unsqueeze
(
2
).
expand_as
(
start_states
)
# shape (bsz, slen, start_n_top, hsz)
p_mask
=
p_mask
.
unsqueeze
(
-
1
)
if
p_mask
is
not
None
else
None
end_logits
=
self
.
end_logits
(
hidden_states_expanded
,
start_states
=
start_states
,
p_mask
=
p_mask
)
end_log_probs
=
F
.
softmax
(
end_logits
,
dim
=
1
)
# shape (bsz, slen, start_n_top)
end_top_log_probs
,
end_top_index
=
torch
.
topk
(
end_log_probs
,
self
.
end_n_top
,
dim
=
1
)
# shape (bsz, end_n_top, start_n_top)
end_top_log_probs
=
end_top_log_probs
.
view
(
-
1
,
self
.
start_n_top
*
self
.
end_n_top
)
end_top_index
=
end_top_index
.
view
(
-
1
,
self
.
start_n_top
*
self
.
end_n_top
)
start_states
=
torch
.
einsum
(
"blh,bl->bh"
,
hidden_states
,
start_log_probs
)
cls_logits
=
self
.
answer_class
(
hidden_states
,
start_states
=
start_states
,
cls_index
=
cls_index
)
outputs
=
(
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
)
+
outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
return
outputs
pytorch_pretrained_bert/tests/modeling_openai_test.py
View file @
15b70338
...
...
@@ -16,11 +16,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
unittest
import
json
import
random
import
shutil
import
pytest
import
torch
...
...
pytorch_pretrained_bert/tests/modeling_xlm_test.py
View file @
15b70338
...
...
@@ -20,7 +20,7 @@ import unittest
import
shutil
import
pytest
from
pytorch_pretrained_bert
import
(
XLMConfig
,
XLMModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
)
from
pytorch_pretrained_bert
import
(
XLMConfig
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
)
from
pytorch_pretrained_bert.modeling_xlm
import
PRETRAINED_MODEL_ARCHIVE_MAP
from
.model_tests_commons
import
(
create_and_check_commons
,
ConfigTester
,
ids_tensor
)
...
...
@@ -58,7 +58,8 @@ class XLMModelTest(unittest.TestCase):
summary_type
=
"last"
,
use_proj
=
True
,
scope
=
None
,
all_model_classes
=
(
XLMModel
,),
# , XLMForSequenceClassification, XLMForTokenClassification),
all_model_classes
=
(
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
),
# , XLMForSequenceClassification, XLMForTokenClassification),
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
...
...
@@ -93,6 +94,7 @@ class XLMModelTest(unittest.TestCase):
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_lengths
=
None
if
self
.
use_input_lengths
:
...
...
@@ -104,11 +106,11 @@ class XLMModelTest(unittest.TestCase):
sequence_labels
=
None
token_labels
=
None
choic
e_labels
=
None
is_impossibl
e_labels
=
None
if
self
.
use_labels
:
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choic
e_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
is_impossibl
e_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
(
)
config
=
XLMConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
...
@@ -128,14 +130,14 @@ class XLMModelTest(unittest.TestCase):
summary_type
=
self
.
summary_type
,
use_proj
=
self
.
use_proj
)
return
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
def check_loss_output(self, result):
    """Assert that ``result["loss"]`` is a scalar.

    The shape reported by ``result["loss"].size()`` is converted to a list
    and compared with the empty list through the parent test case.
    """
    loss_shape = list(result["loss"].size())
    self.parent.assertListEqual(loss_shape, [])
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
model
=
XLMModel
(
config
=
config
)
model
.
eval
()
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
...
...
@@ -150,90 +152,92 @@ class XLMModelTest(unittest.TestCase):
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
# def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# model = XLMForMaskedLM(config=config)
# model.eval()
# loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels)
# result = {
# "loss": loss,
# "prediction_scores": prediction_scores,
# }
# self.parent.assertListEqual(
# list(result["prediction_scores"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.check_loss_output(result)
# def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# model = XLMForQuestionAnswering(config=config)
# model.eval()
# loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels)
# result = {
# "loss": loss,
# "start_logits": start_logits,
# "end_logits": end_logits,
# }
# self.parent.assertListEqual(
# list(result["start_logits"].size()),
# [self.batch_size, self.seq_length])
# self.parent.assertListEqual(
# list(result["end_logits"].size()),
# [self.batch_size, self.seq_length])
# self.check_loss_output(result)
# def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# config.num_labels = self.num_labels
# model = XLMForSequenceClassification(config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_labels])
# self.check_loss_output(result)
# def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# config.num_labels = self.num_labels
# model = XLMForTokenClassification(config=config)
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.seq_length, self.num_labels])
# self.check_loss_output(result)
# def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
# config.num_choices = self.num_choices
# model = XLMForMultipleChoice(config=config)
# model.eval()
# multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# loss, logits = model(multiple_choice_inputs_ids,
# multiple_choice_token_type_ids,
# multiple_choice_input_lengths,
# choice_labels)
# result = {
# "loss": loss,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.num_choices])
# self.check_loss_output(result)
def
create_and_check_xlm_commons
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
    """Run ``XLMWithLMHeadModel`` with labels and check its output shapes.

    The model is built from ``config``, switched to eval mode and called
    with ``input_ids``, ``token_type_ids`` and ``token_labels`` as labels.
    The returned pair is unpacked as (loss, logits); the loss must be a
    scalar and the logits must have shape
    [batch_size, seq_length, vocab_size].

    The remaining arguments are accepted so that every checker can be
    called with the same argument tuple; they are not used here.
    """
    lm_model = XLMWithLMHeadModel(config)
    lm_model.eval()

    lm_loss, lm_logits = lm_model(input_ids, token_type_ids=token_type_ids, labels=token_labels)

    # Loss is expected to be a 0-dim tensor, i.e. an empty shape.
    loss_shape = list(lm_loss.size())
    self.parent.assertListEqual(loss_shape, [])

    logits_shape = list(lm_logits.size())
    self.parent.assertListEqual(
        logits_shape,
        [self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
    """Exercise ``XLMForQuestionAnswering`` through four call patterns.

    1. No labels: the output is unpacked into six values, which checks
       only the arity of the returned tuple.
    2. All labels plus ``p_mask``: the call must run; its output is not
       inspected.
    3. All labels without ``p_mask``: the output is unpacked into
       (total_loss, start_logits, end_logits, cls_logits).
    4. Start/end positions only: the output is unpacked into
       (total_loss, start_logits, end_logits).

    Shape assertions use the loss and start/end logits of call 4 and the
    ``cls_logits`` of call 3. ``sequence_labels`` is reused as start
    positions, end positions and ``cls_index``.
    """
    qa_model = XLMForQuestionAnswering(config)
    qa_model.eval()

    # 1. Inference-style call; the unpacking itself is the check.
    inference_outputs = qa_model(input_ids)
    (start_top_log_probs, start_top_index,
     end_top_log_probs, end_top_index,
     cls_logits, mems) = inference_outputs

    # 2. Every label plus the position mask; result deliberately unused.
    qa_model(input_ids,
             start_positions=sequence_labels,
             end_positions=sequence_labels,
             cls_index=sequence_labels,
             is_impossible=is_impossible_labels,
             p_mask=input_mask)

    # 3. Every label, no position mask.
    total_loss, start_logits, end_logits, cls_logits = qa_model(
        input_ids,
        start_positions=sequence_labels,
        end_positions=sequence_labels,
        cls_index=sequence_labels,
        is_impossible=is_impossible_labels)

    # 4. Span labels only; overrides loss and start/end logits from call 3.
    total_loss, start_logits, end_logits = qa_model(
        input_ids,
        start_positions=sequence_labels,
        end_positions=sequence_labels)

    self.parent.assertListEqual(list(total_loss.size()), [])
    self.parent.assertListEqual(
        list(start_logits.size()),
        [self.batch_size, self.seq_length])
    self.parent.assertListEqual(
        list(end_logits.size()),
        [self.batch_size, self.seq_length])
    self.parent.assertListEqual(
        list(cls_logits.size()),
        [self.batch_size])
def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
    """Run ``XLMForSequenceClassification`` with and without labels.

    Without labels the output is unpacked as a one-element tuple
    (logits,). With ``sequence_labels`` it is unpacked as (loss, logits).
    The labelled call's loss must be a scalar and its logits must have
    shape [batch_size, type_sequence_label_size].
    """
    classifier = XLMForSequenceClassification(config)
    classifier.eval()

    # Unlabelled call: the single-element unpacking checks the tuple arity.
    unlabelled_outputs = classifier(input_ids)
    (logits,) = unlabelled_outputs

    # Labelled call: its loss and logits are the values asserted below.
    labelled_outputs = classifier(input_ids, labels=sequence_labels)
    loss, logits = labelled_outputs

    self.parent.assertListEqual(list(loss.size()), [])
    self.parent.assertListEqual(
        list(logits.size()),
        [self.batch_size, self.type_sequence_label_size])
def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
    """Hand the XLM inputs to the shared ``create_and_check_commons`` checks.

    Only ``input_ids``, ``token_type_ids`` and ``input_lengths`` are
    forwarded (the latter under the key 'lengths'); the label and mask
    arguments are not used here.
    """
    inputs_dict = dict()
    inputs_dict['input_ids'] = input_ids
    inputs_dict['token_type_ids'] = token_type_ids
    inputs_dict['lengths'] = input_lengths
    create_and_check_commons(self, config, inputs_dict)
...
...
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
View file @
15b70338
...
...
@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
d_inner
=
128
,
num_hidden_layers
=
5
,
max_position_embeddings
=
10
,
type_sequence_label_size
=
2
,
untie_r
=
True
,
bi_data
=
False
,
same_length
=
False
,
...
...
@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
self
.
initializer_range
=
initializer_range
self
.
seed
=
seed
self
.
type_vocab_size
=
type_vocab_size
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
all_model_classes
=
all_model_classes
def
prepare_config_and_inputs
(
self
):
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
)
...
...
@@ -94,30 +97,13 @@ class XLNetModelTest(unittest.TestCase):
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
inp_q
=
target_mapping
[:,
0
,
:].
clone
()
# predict last token
# inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
# 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
# from previous batches. The length of the list equals num_hidden_layers.
# If None, no memory is used.
# perm_mask: float32 Tensor in shape [bsz, len, len].
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
# if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
# If None, each position attends to all the others.
# target_mapping: float32 Tensor in shape [bsz, num_predict, len].
# If target_mapping[k, i, j] = 1, the i-th predict in batch k is
# on the j-th token.
# Only used during pretraining for partial prediction.
# Set to None during finetuning.
# inp_q: float32 Tensor in shape [bsz, len].
# 1 for tokens with losses and 0 for tokens without losses.
# Only used during pretraining for two-stream attention.
# Set to None during finetuning.
sequence_labels
=
None
lm_labels
=
None
is_impossible_labels
=
None
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
is_impossible_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
()
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
...
@@ -132,18 +118,23 @@ class XLNetModelTest(unittest.TestCase):
same_length
=
self
.
same_length
,
reuse_len
=
self
.
reuse_len
,
bi_data
=
self
.
bi_data
,
initializer_range
=
self
.
initializer_range
)
initializer_range
=
self
.
initializer_range
,
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
XLNetModel
(
config
)
model
.
eval
()
_
,
_
=
model
(
input_ids_1
,
input_mask
=
input_mask
)
_
,
_
=
model
(
input_ids_1
,
attention_mask
=
input_mask
)
_
,
_
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
outputs
,
mems_1
=
model
(
input_ids_1
)
...
...
@@ -159,7 +150,8 @@ class XLNetModelTest(unittest.TestCase):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
...
...
@@ -198,7 +190,82 @@ class XLNetModelTest(unittest.TestCase):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_commons
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
    """Exercise ``XLNetForQuestionAnswering`` through four call patterns.

    1. No labels: the output is unpacked into six values ending with
       ``mems``, which checks only the arity of the returned tuple.
    2. All labels plus ``p_mask``: the call must run; its output is not
       inspected.
    3. All labels without ``p_mask``: the output is unpacked into
       (total_loss, start_logits, end_logits, cls_logits, mems).
    4. Start/end positions only: the output is unpacked into
       (total_loss, start_logits, end_logits, mems).

    Shape assertions use the loss, start/end logits and ``mems`` of call 4
    and the ``cls_logits`` of call 3. ``sequence_labels`` is reused as
    start positions, end positions and ``cls_index``.
    """
    qa_model = XLNetForQuestionAnswering(config)
    qa_model.eval()

    # 1. Inference-style call; the unpacking itself is the check.
    inference_outputs = qa_model(input_ids_1)
    (start_top_log_probs, start_top_index,
     end_top_log_probs, end_top_index,
     cls_logits, mems) = inference_outputs

    # 2. Every label plus the position mask; result deliberately unused.
    qa_model(input_ids_1,
             start_positions=sequence_labels,
             end_positions=sequence_labels,
             cls_index=sequence_labels,
             is_impossible=is_impossible_labels,
             p_mask=input_mask)

    # 3. Every label, no position mask.
    total_loss, start_logits, end_logits, cls_logits, mems = qa_model(
        input_ids_1,
        start_positions=sequence_labels,
        end_positions=sequence_labels,
        cls_index=sequence_labels,
        is_impossible=is_impossible_labels)

    # 4. Span labels only; overrides everything from call 3 except cls_logits.
    total_loss, start_logits, end_logits, mems = qa_model(
        input_ids_1,
        start_positions=sequence_labels,
        end_positions=sequence_labels)

    self.parent.assertListEqual(list(total_loss.size()), [])
    self.parent.assertListEqual(
        list(start_logits.size()),
        [self.batch_size, self.seq_length])
    self.parent.assertListEqual(
        list(end_logits.size()),
        [self.batch_size, self.seq_length])
    self.parent.assertListEqual(
        list(cls_logits.size()),
        [self.batch_size])

    # One memory tensor is expected per hidden layer, all of the same shape.
    mem_shapes = [list(mem.size()) for mem in mems]
    self.parent.assertListEqual(
        mem_shapes,
        [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
    """Run ``XLNetForSequenceClassification`` with and without labels.

    Without labels the output is unpacked as (logits, mems). With
    ``sequence_labels`` it is unpacked as (loss, logits, mems). The
    labelled call's loss must be a scalar, its logits must have shape
    [batch_size, type_sequence_label_size], and its memories must be
    ``num_hidden_layers`` tensors of shape
    [seq_length, batch_size, hidden_size].
    """
    classifier = XLNetForSequenceClassification(config)
    classifier.eval()

    # Unlabelled call: the two-element unpacking checks the tuple arity.
    unlabelled_outputs = classifier(input_ids_1)
    logits, mems_1 = unlabelled_outputs

    # Labelled call: its outputs are the values asserted below.
    labelled_outputs = classifier(input_ids_1, labels=sequence_labels)
    loss, logits, mems_1 = labelled_outputs

    self.parent.assertListEqual(list(loss.size()), [])
    self.parent.assertListEqual(
        list(logits.size()),
        [self.batch_size, self.type_sequence_label_size])

    mem_shapes = [list(mem.size()) for mem in mems_1]
    self.parent.assertListEqual(
        mem_shapes,
        [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
    """Hand the XLNet inputs to the shared ``create_and_check_commons`` checks.

    Only ``input_ids_1`` is forwarded (under the key 'input_ids'), and the
    shared checks are called with ``test_pruning=False``.
    """
    inputs_dict = dict()
    inputs_dict['input_ids'] = input_ids_1
    create_and_check_commons(self, config, inputs_dict, test_pruning=False)
...
...
@@ -228,23 +295,15 @@ class XLNetModelTest(unittest.TestCase):
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_commons
(
*
config_and_inputs
)
@
classmethod
def
mask_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a tensor with padding on the right (0.0 for )."""
if
rng
is
None
:
rng
=
random
.
Random
()
total_dims
=
1
for
dim
in
shape
:
total_dims
*=
dim
tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
values
=
[]
for
_
in
range
(
total_dims
):
values
.
append
(
rng
.
r
andin
t
(
0
,
vocab_size
-
1
)
)
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_qa
(
*
config_
and
_
in
puts
)
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
long
).
view
(
shape
).
contiguous
()
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_commons
(
*
config_and_inputs
)
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment