Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
15b70338
Commit
15b70338
authored
Jul 04, 2019
by
thomwolf
Browse files
adding squad model to xlnet and xlm
parent
fbe04423
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
531 additions
and
326 deletions
+531
-326
pytorch_pretrained_bert/model_utils.py
pytorch_pretrained_bert/model_utils.py
+182
-15
pytorch_pretrained_bert/modeling_xlm.py
pytorch_pretrained_bert/modeling_xlm.py
+86
-149
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+68
-26
pytorch_pretrained_bert/tests/modeling_openai_test.py
pytorch_pretrained_bert/tests/modeling_openai_test.py
+0
-4
pytorch_pretrained_bert/tests/modeling_xlm_test.py
pytorch_pretrained_bert/tests/modeling_xlm_test.py
+94
-90
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
+101
-42
No files found.
pytorch_pretrained_bert/model_utils.py
View file @
15b70338
...
@@ -25,7 +25,7 @@ from io import open
...
@@ -25,7 +25,7 @@ from io import open
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
,
functional
as
F
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
...
@@ -301,22 +301,189 @@ class Conv1D(nn.Module):
...
@@ -301,22 +301,189 @@ class Conv1D(nn.Module):
return
x
return
x
class
SequenceSummary
(
nn
.
Module
):
class
PoolerStartLogits
(
nn
.
Module
):
""" Compute SQuAD start_logits from sequence hidden states. """
def
__init__
(
self
,
config
):
super
(
PoolerStartLogits
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
1
)
def
forward
(
self
,
hidden_states
,
p_mask
=
None
):
""" Args:
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
shape [batch_size, seq_len]. 1.0 means token should be masked.
"""
x
=
self
.
dense
(
hidden_states
).
squeeze
(
-
1
)
if
p_mask
is
not
None
:
x
=
x
*
(
1
-
p_mask
)
-
1e30
*
p_mask
return
x
class
PoolerEndLogits
(
nn
.
Module
):
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
"""
def
__init__
(
self
,
config
):
super
(
PoolerEndLogits
,
self
).
__init__
()
self
.
dense_0
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
self
.
LayerNorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
dense_1
=
nn
.
Linear
(
config
.
hidden_size
,
1
)
def
forward
(
self
,
hidden_states
,
start_states
=
None
,
start_positions
=
None
,
p_mask
=
None
):
""" Args:
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
shape [batch_size, seq_len]. 1.0 means token should be masked.
"""
slen
,
hsz
=
hidden_states
.
shape
[
-
2
:]
assert
start_states
is
not
None
or
start_positions
is
not
None
,
"One of start_states, start_positions should be not None"
if
start_positions
is
not
None
:
start_positions
=
start_positions
[:,
None
,
None
].
expand
(
-
1
,
-
1
,
hsz
)
# shape (bsz, 1, hsz)
start_states
=
hidden_states
.
gather
(
-
2
,
start_positions
)
# shape (bsz, 1, hsz)
start_states
=
start_states
.
expand
(
-
1
,
slen
,
-
1
)
# shape (bsz, slen, hsz)
x
=
self
.
dense_0
(
torch
.
cat
([
hidden_states
,
start_states
],
dim
=-
1
))
x
=
self
.
activation
(
x
)
x
=
self
.
LayerNorm
(
x
)
x
=
self
.
dense_1
(
x
).
squeeze
(
-
1
)
if
p_mask
is
not
None
:
x
=
x
*
(
1
-
p_mask
)
-
1e30
*
p_mask
return
x
class
PoolerAnswerClass
(
nn
.
Module
):
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
def
__init__
(
self
,
config
):
super
(
PoolerAnswerClass
,
self
).
__init__
()
self
.
dense_0
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
self
.
dense_1
=
nn
.
Linear
(
config
.
hidden_size
,
1
,
bias
=
False
)
def
forward
(
self
,
hidden_states
,
start_states
=
None
,
start_positions
=
None
,
cls_index
=
None
):
""" Args:
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
`cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
# note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
"""
slen
,
hsz
=
hidden_states
.
shape
[
-
2
:]
assert
start_states
is
not
None
or
start_positions
is
not
None
,
"One of start_states, start_positions should be not None"
if
start_positions
is
not
None
:
start_positions
=
start_positions
[:,
None
,
None
].
expand
(
-
1
,
-
1
,
hsz
)
# shape (bsz, 1, hsz)
start_states
=
hidden_states
.
gather
(
-
2
,
start_positions
).
squeeze
(
-
2
)
# shape (bsz, hsz)
if
cls_index
is
not
None
:
cls_index
=
cls_index
[:,
None
,
None
].
expand
(
-
1
,
-
1
,
hsz
)
# shape (bsz, 1, hsz)
cls_token_state
=
hidden_states
.
gather
(
-
2
,
cls_index
).
squeeze
(
-
2
)
# shape (bsz, hsz)
else
:
cls_token_state
=
hidden_states
[:,
-
1
,
:]
# shape (bsz, hsz)
x
=
self
.
dense_0
(
torch
.
cat
([
start_states
,
cls_token_state
],
dim
=-
1
))
x
=
self
.
activation
(
x
)
x
=
self
.
dense_1
(
x
).
squeeze
(
-
1
)
return
x
class
SQuADHead
(
nn
.
Module
):
""" A SQuAD head inspired by XLNet.
Compute
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
super
(
SQuADHead
,
self
).
__init__
()
Args of the config class:
self
.
start_n_top
=
config
.
start_n_top
summary_type:
self
.
end_n_top
=
config
.
end_n_top
- 'last' => [default] take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like Bert)
self
.
start_logits
=
PoolerStartLogits
(
config
)
- 'mean' => take the mean of all tokens hidden states
self
.
end_logits
=
PoolerEndLogits
(
config
)
- 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
self
.
answer_class
=
PoolerAnswerClass
(
config
)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
def
forward
(
self
,
hidden_states
,
start_positions
=
None
,
end_positions
=
None
,
summary_num_classes: If > 0: the projection outputs to n classes (otherwise to hidden_size)
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
):
summary_activation:
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
'tanh' => add a tanh activation to the output
None => no activation
"""
"""
outputs
=
()
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
for
x
in
(
start_positions
,
end_positions
,
cls_index
,
is_impossible
):
if
x
is
not
None
and
x
.
dim
()
>
1
:
x
.
squeeze_
(
-
1
)
# during training, compute the end logits based on the ground truth of the start position
end_logits
=
self
.
end_logits
(
hidden_states
,
start_positions
=
start_positions
,
p_mask
=
p_mask
)
loss_fct
=
CrossEntropyLoss
()
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
if
cls_index
is
not
None
and
is_impossible
is
not
None
:
# Predict answerability from the representation of CLS and START
cls_logits
=
self
.
answer_class
(
hidden_states
,
start_positions
=
start_positions
,
cls_index
=
cls_index
)
loss_fct_cls
=
nn
.
BCEWithLogitsLoss
()
cls_loss
=
loss_fct_cls
(
cls_logits
,
is_impossible
)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss
+=
cls_loss
*
0.5
outputs
=
(
total_loss
,
start_logits
,
end_logits
,
cls_logits
)
+
outputs
else
:
outputs
=
(
total_loss
,
start_logits
,
end_logits
)
+
outputs
else
:
# during inference, compute the end logits based on beam search
bsz
,
slen
,
hsz
=
hidden_states
.
size
()
start_log_probs
=
F
.
softmax
(
start_logits
,
dim
=-
1
)
# shape (bsz, slen)
start_top_log_probs
,
start_top_index
=
torch
.
topk
(
start_log_probs
,
self
.
start_n_top
,
dim
=-
1
)
# shape (bsz, start_n_top)
start_top_index
=
start_top_index
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
hsz
)
# shape (bsz, start_n_top, hsz)
start_states
=
torch
.
gather
(
hidden_states
,
-
2
,
start_top_index
)
# shape (bsz, start_n_top, hsz)
start_states
=
start_states
.
unsqueeze
(
1
).
expand
(
-
1
,
slen
,
-
1
,
-
1
)
# shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded
=
hidden_states
.
unsqueeze
(
2
).
expand_as
(
start_states
)
# shape (bsz, slen, start_n_top, hsz)
p_mask
=
p_mask
.
unsqueeze
(
-
1
)
if
p_mask
is
not
None
else
None
end_logits
=
self
.
end_logits
(
hidden_states_expanded
,
start_states
=
start_states
,
p_mask
=
p_mask
)
end_log_probs
=
F
.
softmax
(
end_logits
,
dim
=
1
)
# shape (bsz, slen, start_n_top)
end_top_log_probs
,
end_top_index
=
torch
.
topk
(
end_log_probs
,
self
.
end_n_top
,
dim
=
1
)
# shape (bsz, end_n_top, start_n_top)
end_top_log_probs
=
end_top_log_probs
.
view
(
-
1
,
self
.
start_n_top
*
self
.
end_n_top
)
end_top_index
=
end_top_index
.
view
(
-
1
,
self
.
start_n_top
*
self
.
end_n_top
)
start_states
=
torch
.
einsum
(
"blh,bl->bh"
,
hidden_states
,
start_log_probs
)
cls_logits
=
self
.
answer_class
(
hidden_states
,
start_states
=
start_states
,
cls_index
=
cls_index
)
outputs
=
(
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
)
+
outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
return
outputs
class
SequenceSummary
(
nn
.
Module
):
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
- 'last' => [default] take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like Bert)
- 'mean' => take the mean of all tokens hidden states
- 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
summary_num_classes: If > 0: the projection outputs to n classes (otherwise to hidden_size)
summary_activation:
'tanh' => add a tanh activation to the output
None => no activation
"""
def
__init__
(
self
,
config
):
super
(
SequenceSummary
,
self
).
__init__
()
super
(
SequenceSummary
,
self
).
__init__
()
self
.
summary_type
=
config
.
summary_type
if
hasattr
(
config
,
'summary_use_proj'
)
else
'last'
self
.
summary_type
=
config
.
summary_type
if
hasattr
(
config
,
'summary_use_proj'
)
else
'last'
...
...
pytorch_pretrained_bert/modeling_xlm.py
View file @
15b70338
...
@@ -35,7 +35,8 @@ from torch.nn import functional as F
...
@@ -35,7 +35,8 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
from
.model_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
,
SequenceSummary
,
SQuADHead
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -67,15 +68,23 @@ class XLMConfig(PretrainedConfig):
...
@@ -67,15 +68,23 @@ class XLMConfig(PretrainedConfig):
n_langs
=
1
,
n_langs
=
1
,
max_position_embeddings
=
512
,
max_position_embeddings
=
512
,
embed_init_std
=
2048
**
-
0.5
,
embed_init_std
=
2048
**
-
0.5
,
layer_norm_eps
=
1e-12
,
init_std
=
0.02
,
init_std
=
0.02
,
summary_type
=
"last"
,
use_proj
=
True
,
bos_index
=
0
,
bos_index
=
0
,
eos_index
=
1
,
eos_index
=
1
,
pad_index
=
2
,
pad_index
=
2
,
unk_index
=
3
,
unk_index
=
3
,
mask_index
=
5
,
mask_index
=
5
,
is_encoder
=
True
,
is_encoder
=
True
,
finetuning_task
=
None
,
num_labels
=
2
,
summary_type
=
'last'
,
summary_use_proj
=
True
,
summary_activation
=
'tanh'
,
summary_dropout
=
0.1
,
start_n_top
=
5
,
end_n_top
=
5
,
**
kwargs
):
**
kwargs
):
"""Constructs XLMConfig.
"""Constructs XLMConfig.
...
@@ -140,8 +149,7 @@ class XLMConfig(PretrainedConfig):
...
@@ -140,8 +149,7 @@ class XLMConfig(PretrainedConfig):
self
.
causal
=
causal
self
.
causal
=
causal
self
.
asm
=
asm
self
.
asm
=
asm
self
.
n_langs
=
n_langs
self
.
n_langs
=
n_langs
self
.
summary_type
=
summary_type
self
.
layer_norm_eps
=
layer_norm_eps
self
.
use_proj
=
use_proj
self
.
bos_index
=
bos_index
self
.
bos_index
=
bos_index
self
.
eos_index
=
eos_index
self
.
eos_index
=
eos_index
self
.
pad_index
=
pad_index
self
.
pad_index
=
pad_index
...
@@ -151,6 +159,14 @@ class XLMConfig(PretrainedConfig):
...
@@ -151,6 +159,14 @@ class XLMConfig(PretrainedConfig):
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
embed_init_std
=
embed_init_std
self
.
embed_init_std
=
embed_init_std
self
.
init_std
=
init_std
self
.
init_std
=
init_std
self
.
finetuning_task
=
finetuning_task
self
.
num_labels
=
num_labels
self
.
summary_type
=
summary_type
self
.
summary_use_proj
=
summary_use_proj
self
.
summary_activation
=
summary_activation
self
.
summary_dropout
=
summary_dropout
self
.
start_n_top
=
start_n_top
self
.
end_n_top
=
end_n_top
else
:
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
"or the path to a pretrained model config file (str)"
)
...
@@ -172,26 +188,6 @@ class XLMConfig(PretrainedConfig):
...
@@ -172,26 +188,6 @@ class XLMConfig(PretrainedConfig):
return
self
.
n_layers
return
self
.
n_layers
def
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
=
None
,
config
=
None
):
m
=
nn
.
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
=
padding_idx
)
if
config
is
not
None
and
config
.
embed_init_std
is
not
None
:
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
config
.
embed_init_std
)
if
padding_idx
is
not
None
:
nn
.
init
.
constant_
(
m
.
weight
[
padding_idx
],
0
)
return
m
def
Linear
(
in_features
,
out_features
,
bias
=
True
,
config
=
None
):
m
=
nn
.
Linear
(
in_features
,
out_features
,
bias
)
if
config
is
not
None
and
config
.
init_std
is
not
None
:
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
config
.
init_std
)
if
bias
:
nn
.
init
.
constant_
(
m
.
bias
,
0.
)
# nn.init.xavier_uniform_(m.weight)
# nn.init.constant_(m.bias, 0.)
return
m
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
position_enc
=
np
.
array
([
position_enc
=
np
.
array
([
[
pos
/
np
.
power
(
10000
,
2
*
(
j
//
2
)
/
dim
)
for
j
in
range
(
dim
)]
[
pos
/
np
.
power
(
10000
,
2
*
(
j
//
2
)
/
dim
)
for
j
in
range
(
dim
)]
...
@@ -244,7 +240,7 @@ class MultiHeadAttention(nn.Module):
...
@@ -244,7 +240,7 @@ class MultiHeadAttention(nn.Module):
NEW_ID
=
itertools
.
count
()
NEW_ID
=
itertools
.
count
()
def
__init__
(
self
,
n_heads
,
dim
,
config
):
def
__init__
(
self
,
n_heads
,
dim
,
config
):
super
().
__init__
()
super
(
MultiHeadAttention
,
self
).
__init__
()
self
.
layer_id
=
next
(
MultiHeadAttention
.
NEW_ID
)
self
.
layer_id
=
next
(
MultiHeadAttention
.
NEW_ID
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_attentions
=
config
.
output_attentions
self
.
dim
=
dim
self
.
dim
=
dim
...
@@ -252,10 +248,10 @@ class MultiHeadAttention(nn.Module):
...
@@ -252,10 +248,10 @@ class MultiHeadAttention(nn.Module):
self
.
dropout
=
config
.
attention_dropout
self
.
dropout
=
config
.
attention_dropout
assert
self
.
dim
%
self
.
n_heads
==
0
assert
self
.
dim
%
self
.
n_heads
==
0
self
.
q_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
q_lin
=
nn
.
Linear
(
dim
,
dim
)
self
.
k_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
k_lin
=
nn
.
Linear
(
dim
,
dim
)
self
.
v_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
v_lin
=
nn
.
Linear
(
dim
,
dim
)
self
.
out_lin
=
Linear
(
dim
,
dim
,
config
=
config
)
self
.
out_lin
=
nn
.
Linear
(
dim
,
dim
)
def
prune_heads
(
self
,
heads
):
def
prune_heads
(
self
,
heads
):
attention_head_size
=
self
.
dim
//
self
.
n_heads
attention_head_size
=
self
.
dim
//
self
.
n_heads
...
@@ -342,10 +338,10 @@ class MultiHeadAttention(nn.Module):
...
@@ -342,10 +338,10 @@ class MultiHeadAttention(nn.Module):
class
TransformerFFN
(
nn
.
Module
):
class
TransformerFFN
(
nn
.
Module
):
def
__init__
(
self
,
in_dim
,
dim_hidden
,
out_dim
,
config
):
def
__init__
(
self
,
in_dim
,
dim_hidden
,
out_dim
,
config
):
super
().
__init__
()
super
(
TransformerFFN
,
self
).
__init__
()
self
.
dropout
=
config
.
dropout
self
.
dropout
=
config
.
dropout
self
.
lin1
=
Linear
(
in_dim
,
dim_hidden
,
config
=
config
)
self
.
lin1
=
nn
.
Linear
(
in_dim
,
dim_hidden
)
self
.
lin2
=
Linear
(
dim_hidden
,
out_dim
,
config
=
config
)
self
.
lin2
=
nn
.
Linear
(
dim_hidden
,
out_dim
)
self
.
act
=
gelu
if
config
.
gelu_activation
else
F
.
relu
self
.
act
=
gelu
if
config
.
gelu_activation
else
F
.
relu
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
...
@@ -363,17 +359,21 @@ class XLMPreTrainedModel(PreTrainedModel):
...
@@ -363,17 +359,21 @@ class XLMPreTrainedModel(PreTrainedModel):
config_class
=
XLMConfig
config_class
=
XLMConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights
=
None
load_tf_weights
=
None
base_model_prefix
=
"
xlm
"
base_model_prefix
=
"
transformer
"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
def
init_weights
(
self
,
module
):
""" Initialize the weights.
""" Initialize the weights. """
"""
if
isinstance
(
module
,
nn
.
Embedding
):
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Embedding
)):
if
self
.
config
is
not
None
and
self
.
config
.
embed_init_std
is
not
None
:
# Weights are initialized in module instantiation (see above)
nn
.
init
.
normal_
(
module
.
weight
,
mean
=
0
,
std
=
self
.
config
.
embed_init_std
)
pass
if
isinstance
(
module
,
nn
.
Linear
):
if
self
.
config
is
not
None
and
self
.
config
.
init_std
is
not
None
:
nn
.
init
.
normal_
(
module
.
weight
,
mean
=
0
,
std
=
self
.
config
.
init_std
)
if
hasattr
(
module
,
'bias'
)
and
module
.
bias
is
not
None
:
nn
.
init
.
constant_
(
module
.
bias
,
0.
)
if
isinstance
(
module
,
nn
.
LayerNorm
):
if
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
module
.
weight
.
data
.
fill_
(
1.0
)
...
@@ -471,13 +471,13 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -471,13 +471,13 @@ class XLMModel(XLMPreTrainedModel):
assert
self
.
dim
%
self
.
n_heads
==
0
,
'transformer dim must be a multiple of n_heads'
assert
self
.
dim
%
self
.
n_heads
==
0
,
'transformer dim must be a multiple of n_heads'
# embeddings
# embeddings
self
.
position_embeddings
=
Embedding
(
config
.
max_position_embeddings
,
self
.
dim
,
config
=
config
)
self
.
position_embeddings
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
dim
)
if
config
.
sinusoidal_embeddings
:
if
config
.
sinusoidal_embeddings
:
create_sinusoidal_embeddings
(
config
.
max_position_embeddings
,
self
.
dim
,
out
=
self
.
position_embeddings
.
weight
)
create_sinusoidal_embeddings
(
config
.
max_position_embeddings
,
self
.
dim
,
out
=
self
.
position_embeddings
.
weight
)
if
config
.
n_langs
>
1
:
if
config
.
n_langs
>
1
:
self
.
lang_embeddings
=
Embedding
(
self
.
n_langs
,
self
.
dim
,
config
=
config
)
self
.
lang_embeddings
=
nn
.
Embedding
(
self
.
n_langs
,
self
.
dim
)
self
.
embeddings
=
Embedding
(
self
.
n_words
,
self
.
dim
,
padding_idx
=
self
.
pad_index
,
config
=
config
)
self
.
embeddings
=
nn
.
Embedding
(
self
.
n_words
,
self
.
dim
,
padding_idx
=
self
.
pad_index
)
self
.
layer_norm_emb
=
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
)
self
.
layer_norm_emb
=
nn
.
LayerNorm
(
self
.
dim
,
eps
=
config
.
layer_norm_eps
)
# transformer layers
# transformer layers
self
.
attentions
=
nn
.
ModuleList
()
self
.
attentions
=
nn
.
ModuleList
()
...
@@ -490,12 +490,14 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -490,12 +490,14 @@ class XLMModel(XLMPreTrainedModel):
for
_
in
range
(
self
.
n_layers
):
for
_
in
range
(
self
.
n_layers
):
self
.
attentions
.
append
(
MultiHeadAttention
(
self
.
n_heads
,
self
.
dim
,
config
=
config
))
self
.
attentions
.
append
(
MultiHeadAttention
(
self
.
n_heads
,
self
.
dim
,
config
=
config
))
self
.
layer_norm1
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
self
.
layer_norm1
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
config
.
layer_norm_eps
))
# if self.is_decoder:
# if self.is_decoder:
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=
1e-12
))
# self.layer_norm15.append(nn.LayerNorm(self.dim, eps=
config.layer_norm_eps
))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
# self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
config
=
config
))
self
.
ffns
.
append
(
TransformerFFN
(
self
.
dim
,
self
.
hidden_dim
,
self
.
dim
,
config
=
config
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
1e-12
))
self
.
layer_norm2
.
append
(
nn
.
LayerNorm
(
self
.
dim
,
eps
=
config
.
layer_norm_eps
))
self
.
apply
(
self
.
init_weights
)
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
""" Prunes heads of the model.
...
@@ -636,14 +638,14 @@ class XLMPredLayer(nn.Module):
...
@@ -636,14 +638,14 @@ class XLMPredLayer(nn.Module):
Prediction layer (cross_entropy or adaptive_softmax).
Prediction layer (cross_entropy or adaptive_softmax).
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
()
super
(
XLMPredLayer
,
self
).
__init__
()
self
.
asm
=
config
.
asm
self
.
asm
=
config
.
asm
self
.
n_words
=
config
.
n_words
self
.
n_words
=
config
.
n_words
self
.
pad_index
=
config
.
pad_index
self
.
pad_index
=
config
.
pad_index
dim
=
config
.
emb_dim
dim
=
config
.
emb_dim
if
config
.
asm
is
False
:
if
config
.
asm
is
False
:
self
.
proj
=
Linear
(
dim
,
config
.
n_words
,
bias
=
True
)
self
.
proj
=
nn
.
Linear
(
dim
,
config
.
n_words
,
bias
=
True
)
else
:
else
:
self
.
proj
=
nn
.
AdaptiveLogSoftmaxWithLoss
(
self
.
proj
=
nn
.
AdaptiveLogSoftmaxWithLoss
(
in_features
=
dim
,
in_features
=
dim
,
...
@@ -653,28 +655,24 @@ class XLMPredLayer(nn.Module):
...
@@ -653,28 +655,24 @@ class XLMPredLayer(nn.Module):
head_bias
=
True
,
# default is False
head_bias
=
True
,
# default is False
)
)
def
forward
(
self
,
x
,
y
,
get_scores
=
False
):
def
forward
(
self
,
x
,
y
=
None
):
""" Compute the loss, and optionally the scores.
"""
"""
Compute the loss, and optionally the scores.
outputs
=
()
"""
assert
(
y
==
self
.
pad_index
).
sum
().
item
()
==
0
if
self
.
asm
is
False
:
if
self
.
asm
is
False
:
scores
=
self
.
proj
(
x
).
view
(
-
1
,
self
.
n_words
)
scores
=
self
.
proj
(
x
).
view
(
-
1
,
self
.
n_words
)
loss
=
F
.
cross_entropy
(
scores
,
y
,
reduction
=
'elementwise_mean'
)
outputs
=
(
scores
,)
+
outputs
if
y
is
not
None
:
loss
=
F
.
cross_entropy
(
scores
,
y
,
reduction
=
'elementwise_mean'
)
outputs
=
(
loss
,)
+
outputs
else
:
else
:
_
,
loss
=
self
.
proj
(
x
,
y
)
scores
=
self
.
proj
.
log_prob
(
x
)
scores
=
self
.
proj
.
log_prob
(
x
)
if
get_scores
else
None
outputs
=
(
scores
,)
+
outputs
if
y
is
not
None
:
return
scores
,
loss
_
,
loss
=
self
.
proj
(
x
,
y
)
outputs
=
(
loss
,)
+
outputs
def
get_scores
(
self
,
x
):
"""
Compute scores.
"""
assert
x
.
dim
()
==
2
return
self
.
proj
.
log_prob
(
x
)
if
self
.
asm
else
self
.
proj
(
x
)
return
outputs
class
XLMWithLMHeadModel
(
XLMPreTrainedModel
):
class
XLMWithLMHeadModel
(
XLMPreTrainedModel
):
...
@@ -731,6 +729,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -731,6 +729,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
XLMWithLMHeadModel
,
self
).
__init__
(
config
)
super
(
XLMWithLMHeadModel
,
self
).
__init__
(
config
)
self
.
torchscript
=
config
.
torchscript
self
.
transformer
=
XLMModel
(
config
)
self
.
transformer
=
XLMModel
(
config
)
self
.
pred_layer
=
XLMPredLayer
(
config
)
self
.
pred_layer
=
XLMPredLayer
(
config
)
...
@@ -741,7 +740,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -741,7 +740,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def
tie_weights
(
self
):
def
tie_weights
(
self
):
""" Make sure we are sharing the embeddings
""" Make sure we are sharing the embeddings
"""
"""
self
.
pred_layer
.
proj
.
weight
=
self
.
transformer
.
embeddings
.
weight
if
self
.
torchscript
:
self
.
pred_layer
.
proj
.
weight
=
nn
.
Parameter
(
self
.
transformer
.
embeddings
.
weight
.
clone
())
else
:
self
.
pred_layer
.
proj
.
weight
=
self
.
transformer
.
embeddings
.
weight
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
...
@@ -775,55 +777,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -775,55 +777,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
logits
=
self
.
pred_layer
(
output
,
labels
)
outputs
=
self
.
pred_layer
(
output
,
labels
)
outputs
=
outputs
+
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
if
labels
is
not
None
:
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
size
(
-
1
)),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
[
logits
]
+
outputs
return
outputs
return
outputs
class
XLMSequenceSummary
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
XLMSequenceSummary
,
self
).
__init__
()
self
.
summary_type
=
config
.
summary_type
if
config
.
use_proj
:
self
.
summary
=
nn
.
Linear
(
config
.
d_model
,
config
.
d_model
)
else
:
self
.
summary
=
None
if
config
.
summary_type
==
'attn'
:
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise
NotImplementedError
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
self
.
activation
=
nn
.
Tanh
()
def
forward
(
self
,
hidden_states
):
""" hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
if
self
.
summary_type
==
'last'
:
output
=
hidden_states
[:,
-
1
]
elif
self
.
summary_type
==
'first'
:
output
=
hidden_states
[:,
0
]
elif
self
.
summary_type
==
'mean'
:
output
=
hidden_states
.
mean
(
dim
=
1
)
elif
summary_type
==
'attn'
:
raise
NotImplementedError
output
=
self
.
summary
(
output
)
output
=
self
.
activation
(
output
)
output
=
self
.
dropout
(
output
)
return
output
class
XLMForSequenceClassification
(
XLMPreTrainedModel
):
class
XLMForSequenceClassification
(
XLMPreTrainedModel
):
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
...
@@ -890,15 +849,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
...
@@ -890,15 +849,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
XLMForSequenceClassification
,
self
).
__init__
(
config
)
super
(
XLMForSequenceClassification
,
self
).
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
XLMModel
(
config
)
self
.
transformer
=
XLMModel
(
config
)
self
.
sequence_summary
=
XLMSequenceSummary
(
config
)
self
.
sequence_summary
=
SequenceSummary
(
config
)
self
.
logits_proj
=
nn
.
Linear
(
config
.
d_model
,
config
.
num_labels
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
"""
"""
Args:
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
...
@@ -930,10 +889,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
...
@@ -930,10 +889,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
output
=
self
.
sequence_summary
(
output
)
logits
=
self
.
sequence_summary
(
output
)
logits
=
self
.
logits_proj
(
output
)
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
outputs
=
(
logits
,)
+
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
if
labels
is
not
None
:
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
if
self
.
num_labels
==
1
:
...
@@ -943,9 +901,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
...
@@ -943,9 +901,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
else
:
else
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
(
loss
,)
+
outputs
outputs
=
[
logits
]
+
outputs
return
outputs
return
outputs
...
@@ -1010,41 +966,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
...
@@ -1010,41 +966,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
super
(
XLMForQuestionAnswering
,
self
).
__init__
(
config
)
super
(
XLMForQuestionAnswering
,
self
).
__init__
(
config
)
self
.
transformer
=
XLMModel
(
config
)
self
.
transformer
=
XLMModel
(
config
)
self
.
qa_outputs
=
nn
.
Lin
ea
r
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
qa_outputs
=
SQuADH
ea
d
(
config
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
attention_mask
=
None
,
cache
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
positions
=
None
,
langs
=
None
,
token_type_ids
=
None
,
labels
=
None
,
head_mask
=
None
):
attention_mask
=
None
,
cache
=
None
,
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
positions
=
positions
,
token_type_ids
=
token_type_ids
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
positions
=
positions
,
token_type_ids
=
token_type_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
logits
=
self
.
qa_outputs
(
output
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
outputs
=
self
.
qa_outputs
(
output
,
start_positions
=
start_positions
,
end_positions
=
end_positions
,
start_logits
=
start_logits
.
squeeze
(
-
1
)
cls_index
=
cls_index
,
is_impossible
=
is_impossible
,
p_mask
=
p_mask
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
outputs
=
outputs
+
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
outputs
=
transformer_outputs
[
1
:]
# Keep new_mems and attention/hidden states if they are here
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, split add a dimension
if
len
(
start_positions
.
size
())
>
1
:
start_positions
=
start_positions
.
squeeze
(
-
1
)
if
len
(
end_positions
.
size
())
>
1
:
end_positions
=
end_positions
.
squeeze
(
-
1
)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index
=
start_logits
.
size
(
1
)
start_positions
.
clamp_
(
0
,
ignored_index
)
end_positions
.
clamp_
(
0
,
ignored_index
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=
ignored_index
)
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
outputs
=
[
total_loss
]
+
outputs
outputs
=
[
start_logits
,
end_logits
]
+
outputs
return
outputs
return
outputs
pytorch_pretrained_bert/modeling_xlnet.py
View file @
15b70338
...
@@ -32,8 +32,8 @@ from torch.nn import functional as F
...
@@ -32,8 +32,8 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.model_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
from
.model_utils
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
PretrainedConfig
,
PreTrainedModel
,
SequenceSummary
)
SequenceSummary
,
PoolerAnswerClass
,
PoolerEndLogits
,
PoolerStartLogits
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -221,13 +221,15 @@ class XLNetConfig(PretrainedConfig):
...
@@ -221,13 +221,15 @@ class XLNetConfig(PretrainedConfig):
bi_data
=
False
,
bi_data
=
False
,
clamp_len
=-
1
,
clamp_len
=-
1
,
same_length
=
False
,
same_length
=
False
,
finetuning_task
=
None
,
finetuning_task
=
None
,
num_labels
=
2
,
num_labels
=
2
,
summary_type
=
'last'
,
summary_type
=
'last'
,
summary_use_proj
=
True
,
summary_use_proj
=
True
,
summary_activation
=
'tanh'
,
summary_activation
=
'tanh'
,
summary_dropout
=
0.1
,
summary_dropout
=
0.1
,
start_n_top
=
5
,
end_n_top
=
5
,
**
kwargs
):
**
kwargs
):
"""Constructs XLNetConfig.
"""Constructs XLNetConfig.
...
@@ -313,6 +315,8 @@ class XLNetConfig(PretrainedConfig):
...
@@ -313,6 +315,8 @@ class XLNetConfig(PretrainedConfig):
self
.
summary_use_proj
=
summary_use_proj
self
.
summary_use_proj
=
summary_use_proj
self
.
summary_activation
=
summary_activation
self
.
summary_activation
=
summary_activation
self
.
summary_dropout
=
summary_dropout
self
.
summary_dropout
=
summary_dropout
self
.
start_n_top
=
start_n_top
self
.
end_n_top
=
end_n_top
else
:
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
"or the path to a pretrained model config file (str)"
)
...
@@ -1114,6 +1118,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1114,6 +1118,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
XLNetForSequenceClassification
,
self
).
__init__
(
config
)
super
(
XLNetForSequenceClassification
,
self
).
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
XLNetModel
(
config
)
self
.
transformer
=
XLNetModel
(
config
)
self
.
sequence_summary
=
SequenceSummary
(
config
)
self
.
sequence_summary
=
SequenceSummary
(
config
)
...
@@ -1174,7 +1179,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1174,7 +1179,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
class
XLNetForQuestionAnswering
(
XLNetPreTrainedModel
):
class
XLNetForQuestionAnswering
(
XLNetPreTrainedModel
):
"""XLNet model for Question Answering (span extraction).
"""
XLNet model for Question Answering (span extraction).
This module is composed of the XLNet model with a linear layer on top of
This module is composed of the XLNet model with a linear layer on top of
the sequence output that computes start_logits and end_logits
the sequence output that computes start_logits and end_logits
...
@@ -1231,41 +1236,78 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
...
@@ -1231,41 +1236,78 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
XLNetForQuestionAnswering
,
self
).
__init__
(
config
)
super
(
XLNetForQuestionAnswering
,
self
).
__init__
(
config
)
self
.
start_n_top
=
config
.
start_n_top
self
.
end_n_top
=
config
.
end_n_top
self
.
transformer
=
XLNetModel
(
config
)
self
.
transformer
=
XLNetModel
(
config
)
self
.
qa_outputs
=
nn
.
Linear
(
config
.
d_model
,
config
.
num_labels
)
self
.
start_logits
=
PoolerStartLogits
(
config
)
self
.
end_logits
=
PoolerEndLogits
(
config
)
self
.
answer_class
=
PoolerAnswerClass
(
config
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
start_positions
=
None
,
end_positions
=
None
,
head_mask
=
None
):
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
,
input_mask
,
attention_mask
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
,
input_mask
,
attention_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
head_mask
)
mems
,
perm_mask
,
target_mapping
,
inp_q
,
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
)
logits
=
self
.
qa_
outputs
(
transformer_outputs
[
0
])
outputs
=
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
start_logits
=
start_logits
.
squeeze
(
-
1
)
# If we are on multi-GPU, let's remove the dimension added by batch splitting
end_logits
=
end_logits
.
squeeze
(
-
1
)
for
x
in
(
start_positions
,
end_positions
,
cls_index
,
is_impossible
):
if
x
is
not
None
and
x
.
dim
()
>
1
:
x
.
squeeze_
(
-
1
)
outputs
=
(
start_logits
,
end_logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
# during training, compute the end logits based on the ground truth of the start position
end_logits
=
self
.
end_logits
(
hidden_states
,
start_positions
=
start_positions
,
p_mask
=
p_mask
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
# If we are on multi-GPU, split add a dimension
if
len
(
start_positions
.
size
())
>
1
:
start_positions
=
start_positions
.
squeeze
(
-
1
)
if
len
(
end_positions
.
size
())
>
1
:
end_positions
=
end_positions
.
squeeze
(
-
1
)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index
=
start_logits
.
size
(
1
)
start_positions
.
clamp_
(
0
,
ignored_index
)
end_positions
.
clamp_
(
0
,
ignored_index
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=
ignored_index
)
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
total_loss
=
(
start_loss
+
end_loss
)
/
2
outputs
=
(
total_loss
,)
+
outputs
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
if
cls_index
is
not
None
and
is_impossible
is
not
None
:
# Predict answerability from the representation of CLS and START
cls_logits
=
self
.
answer_class
(
hidden_states
,
start_positions
=
start_positions
,
cls_index
=
cls_index
)
loss_fct_cls
=
nn
.
BCEWithLogitsLoss
()
cls_loss
=
loss_fct_cls
(
cls_logits
,
is_impossible
)
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
# comparable to start_loss and end_loss
total_loss
+=
cls_loss
*
0.5
outputs
=
(
total_loss
,
start_logits
,
end_logits
,
cls_logits
)
+
outputs
else
:
outputs
=
(
total_loss
,
start_logits
,
end_logits
)
+
outputs
else
:
# during inference, compute the end logits based on beam search
bsz
,
slen
,
hsz
=
hidden_states
.
size
()
start_log_probs
=
F
.
softmax
(
start_logits
,
dim
=-
1
)
# shape (bsz, slen)
start_top_log_probs
,
start_top_index
=
torch
.
topk
(
start_log_probs
,
self
.
start_n_top
,
dim
=-
1
)
# shape (bsz, start_n_top)
start_top_index
=
start_top_index
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
hsz
)
# shape (bsz, start_n_top, hsz)
start_states
=
torch
.
gather
(
hidden_states
,
-
2
,
start_top_index
)
# shape (bsz, start_n_top, hsz)
start_states
=
start_states
.
unsqueeze
(
1
).
expand
(
-
1
,
slen
,
-
1
,
-
1
)
# shape (bsz, slen, start_n_top, hsz)
hidden_states_expanded
=
hidden_states
.
unsqueeze
(
2
).
expand_as
(
start_states
)
# shape (bsz, slen, start_n_top, hsz)
p_mask
=
p_mask
.
unsqueeze
(
-
1
)
if
p_mask
is
not
None
else
None
end_logits
=
self
.
end_logits
(
hidden_states_expanded
,
start_states
=
start_states
,
p_mask
=
p_mask
)
end_log_probs
=
F
.
softmax
(
end_logits
,
dim
=
1
)
# shape (bsz, slen, start_n_top)
end_top_log_probs
,
end_top_index
=
torch
.
topk
(
end_log_probs
,
self
.
end_n_top
,
dim
=
1
)
# shape (bsz, end_n_top, start_n_top)
end_top_log_probs
=
end_top_log_probs
.
view
(
-
1
,
self
.
start_n_top
*
self
.
end_n_top
)
end_top_index
=
end_top_index
.
view
(
-
1
,
self
.
start_n_top
*
self
.
end_n_top
)
start_states
=
torch
.
einsum
(
"blh,bl->bh"
,
hidden_states
,
start_log_probs
)
cls_logits
=
self
.
answer_class
(
hidden_states
,
start_states
=
start_states
,
cls_index
=
cls_index
)
outputs
=
(
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
)
+
outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
return
outputs
pytorch_pretrained_bert/tests/modeling_openai_test.py
View file @
15b70338
...
@@ -16,11 +16,7 @@ from __future__ import absolute_import
...
@@ -16,11 +16,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
unittest
import
unittest
import
json
import
random
import
shutil
import
pytest
import
pytest
import
torch
import
torch
...
...
pytorch_pretrained_bert/tests/modeling_xlm_test.py
View file @
15b70338
...
@@ -20,7 +20,7 @@ import unittest
...
@@ -20,7 +20,7 @@ import unittest
import
shutil
import
shutil
import
pytest
import
pytest
from
pytorch_pretrained_bert
import
(
XLMConfig
,
XLMModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
)
from
pytorch_pretrained_bert
import
(
XLMConfig
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
)
from
pytorch_pretrained_bert.modeling_xlm
import
PRETRAINED_MODEL_ARCHIVE_MAP
from
pytorch_pretrained_bert.modeling_xlm
import
PRETRAINED_MODEL_ARCHIVE_MAP
from
.model_tests_commons
import
(
create_and_check_commons
,
ConfigTester
,
ids_tensor
)
from
.model_tests_commons
import
(
create_and_check_commons
,
ConfigTester
,
ids_tensor
)
...
@@ -58,7 +58,8 @@ class XLMModelTest(unittest.TestCase):
...
@@ -58,7 +58,8 @@ class XLMModelTest(unittest.TestCase):
summary_type
=
"last"
,
summary_type
=
"last"
,
use_proj
=
True
,
use_proj
=
True
,
scope
=
None
,
scope
=
None
,
all_model_classes
=
(
XLMModel
,),
# , XLMForSequenceClassification, XLMForTokenClassification),
all_model_classes
=
(
XLMModel
,
XLMWithLMHeadModel
,
XLMForQuestionAnswering
,
XLMForSequenceClassification
),
# , XLMForSequenceClassification, XLMForTokenClassification),
):
):
self
.
parent
=
parent
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
...
@@ -93,6 +94,7 @@ class XLMModelTest(unittest.TestCase):
...
@@ -93,6 +94,7 @@ class XLMModelTest(unittest.TestCase):
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_lengths
=
None
input_lengths
=
None
if
self
.
use_input_lengths
:
if
self
.
use_input_lengths
:
...
@@ -104,11 +106,11 @@ class XLMModelTest(unittest.TestCase):
...
@@ -104,11 +106,11 @@ class XLMModelTest(unittest.TestCase):
sequence_labels
=
None
sequence_labels
=
None
token_labels
=
None
token_labels
=
None
choic
e_labels
=
None
is_impossibl
e_labels
=
None
if
self
.
use_labels
:
if
self
.
use_labels
:
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choic
e_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
is_impossibl
e_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
(
)
config
=
XLMConfig
(
config
=
XLMConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
@@ -128,14 +130,14 @@ class XLMModelTest(unittest.TestCase):
...
@@ -128,14 +130,14 @@ class XLMModelTest(unittest.TestCase):
summary_type
=
self
.
summary_type
,
summary_type
=
self
.
summary_type
,
use_proj
=
self
.
use_proj
)
use_proj
=
self
.
use_proj
)
return
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
def
check_loss_output
(
self
,
result
):
def
check_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
list
(
result
[
"loss"
].
size
()),
[])
[])
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_xlm_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
model
=
XLMModel
(
config
=
config
)
model
=
XLMModel
(
config
=
config
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
outputs
=
model
(
input_ids
,
lengths
=
input_lengths
,
langs
=
token_type_ids
)
...
@@ -150,90 +152,92 @@ class XLMModelTest(unittest.TestCase):
...
@@ -150,90 +152,92 @@ class XLMModelTest(unittest.TestCase):
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
# def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
def
create_and_check_xlm_lm_head
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
# model = XLMForMaskedLM(config=config)
model
=
XLMWithLMHeadModel
(
config
)
# model.eval()
model
.
eval
()
# loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels)
# result = {
loss
,
logits
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
,
labels
=
token_labels
)
# "loss": loss,
# "prediction_scores": prediction_scores,
result
=
{
# }
"loss"
:
loss
,
# self.parent.assertListEqual(
"logits"
:
logits
,
# list(result["prediction_scores"].size()),
}
# [self.batch_size, self.seq_length, self.vocab_size])
# self.check_loss_output(result)
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
# def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
self
.
parent
.
assertListEqual
(
# model = XLMForQuestionAnswering(config=config)
list
(
result
[
"logits"
].
size
()),
# model.eval()
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
# loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels)
# result = {
# "loss": loss,
def
create_and_check_xlm_qa
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
# "start_logits": start_logits,
model
=
XLMForQuestionAnswering
(
config
)
# "end_logits": end_logits,
model
.
eval
()
# }
# self.parent.assertListEqual(
outputs
=
model
(
input_ids
)
# list(result["start_logits"].size()),
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
,
mems
=
outputs
# [self.batch_size, self.seq_length])
# self.parent.assertListEqual(
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
# list(result["end_logits"].size()),
end_positions
=
sequence_labels
,
# [self.batch_size, self.seq_length])
cls_index
=
sequence_labels
,
# self.check_loss_output(result)
is_impossible
=
is_impossible_labels
,
p_mask
=
input_mask
)
# def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
# config.num_labels = self.num_labels
end_positions
=
sequence_labels
,
# model = XLMForSequenceClassification(config)
cls_index
=
sequence_labels
,
# model.eval()
is_impossible
=
is_impossible_labels
)
# loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels)
# result = {
total_loss
,
start_logits
,
end_logits
,
cls_logits
=
outputs
# "loss": loss,
# "logits": logits,
outputs
=
model
(
input_ids
,
start_positions
=
sequence_labels
,
# }
end_positions
=
sequence_labels
)
# self.parent.assertListEqual(
# list(result["logits"].size()),
total_loss
,
start_logits
,
end_logits
=
outputs
# [self.batch_size, self.num_labels])
# self.check_loss_output(result)
result
=
{
"loss"
:
total_loss
,
"start_logits"
:
start_logits
,
# def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
"end_logits"
:
end_logits
,
# config.num_labels = self.num_labels
"cls_logits"
:
cls_logits
,
# model = XLMForTokenClassification(config=config)
}
# model.eval()
# loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels)
self
.
parent
.
assertListEqual
(
# result = {
list
(
result
[
"loss"
].
size
()),
# "loss": loss,
[])
# "logits": logits,
self
.
parent
.
assertListEqual
(
# }
list
(
result
[
"start_logits"
].
size
()),
# self.parent.assertListEqual(
[
self
.
batch_size
,
self
.
seq_length
])
# list(result["logits"].size()),
self
.
parent
.
assertListEqual
(
# [self.batch_size, self.seq_length, self.num_labels])
list
(
result
[
"end_logits"
].
size
()),
# self.check_loss_output(result)
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
# def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
[
self
.
batch_size
])
# config.num_choices = self.num_choices
# model = XLMForMultipleChoice(config=config)
# model.eval()
def
create_and_check_xlm_sequence_classif
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
# multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
model
=
XLMForSequenceClassification
(
config
)
# multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
model
.
eval
()
# multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
# loss, logits = model(multiple_choice_inputs_ids,
(
logits
,)
=
model
(
input_ids
)
# multiple_choice_token_type_ids,
loss
,
logits
=
model
(
input_ids
,
labels
=
sequence_labels
)
# multiple_choice_input_lengths,
# choice_labels)
result
=
{
# result = {
"loss"
:
loss
,
# "loss": loss,
"logits"
:
logits
,
# "logits": logits,
}
# }
# self.parent.assertListEqual(
self
.
parent
.
assertListEqual
(
# list(result["logits"].size()),
list
(
result
[
"loss"
].
size
()),
# [self.batch_size, self.num_choices])
[])
# self.check_loss_output(result)
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
def
create_and_check_xlm_commons
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_xlm_commons
(
self
,
config
,
input_ids
,
token_type_ids
,
input_lengths
,
sequence_labels
,
token_labels
,
is_impossible_labels
,
input_mask
):
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'lengths'
:
input_lengths
}
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'lengths'
:
input_lengths
}
create_and_check_commons
(
self
,
config
,
inputs_dict
)
create_and_check_commons
(
self
,
config
,
inputs_dict
)
...
...
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
View file @
15b70338
...
@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
d_inner
=
128
,
d_inner
=
128
,
num_hidden_layers
=
5
,
num_hidden_layers
=
5
,
max_position_embeddings
=
10
,
max_position_embeddings
=
10
,
type_sequence_label_size
=
2
,
untie_r
=
True
,
untie_r
=
True
,
bi_data
=
False
,
bi_data
=
False
,
same_length
=
False
,
same_length
=
False
,
...
@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
self
.
seed
=
seed
self
.
seed
=
seed
self
.
type_vocab_size
=
type_vocab_size
self
.
type_vocab_size
=
type_vocab_size
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
all_model_classes
=
all_model_classes
self
.
all_model_classes
=
all_model_classes
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
)
...
@@ -94,30 +97,13 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -94,30 +97,13 @@ class XLNetModelTest(unittest.TestCase):
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
inp_q
=
target_mapping
[:,
0
,
:].
clone
()
# predict last token
inp_q
=
target_mapping
[:,
0
,
:].
clone
()
# predict last token
# inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
sequence_labels
=
None
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
# 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
# from previous batches. The length of the list equals num_hidden_layers.
# If None, no memory is used.
# perm_mask: float32 Tensor in shape [bsz, len, len].
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
# if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
# If None, each position attends to all the others.
# target_mapping: float32 Tensor in shape [bsz, num_predict, len].
# If target_mapping[k, i, j] = 1, the i-th predict in batch k is
# on the j-th token.
# Only used during pretraining for partial prediction.
# Set to None during finetuning.
# inp_q: float32 Tensor in shape [bsz, len].
# 1 for tokens with losses and 0 for tokens without losses.
# Only used during pretraining for two-stream attention.
# Set to None during finetuning.
lm_labels
=
None
lm_labels
=
None
is_impossible_labels
=
None
if
self
.
use_labels
:
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
is_impossible_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
()
config
=
XLNetConfig
(
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
@@ -132,18 +118,23 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -132,18 +118,23 @@ class XLNetModelTest(unittest.TestCase):
same_length
=
self
.
same_length
,
same_length
=
self
.
same_length
,
reuse_len
=
self
.
reuse_len
,
reuse_len
=
self
.
reuse_len
,
bi_data
=
self
.
bi_data
,
bi_data
=
self
.
bi_data
,
initializer_range
=
self
.
initializer_range
)
initializer_range
=
self
.
initializer_range
,
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
def
set_seed
(
self
):
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
XLNetModel
(
config
)
model
=
XLNetModel
(
config
)
model
.
eval
()
model
.
eval
()
_
,
_
=
model
(
input_ids_1
,
input_mask
=
input_mask
)
_
,
_
=
model
(
input_ids_1
,
attention_mask
=
input_mask
)
_
,
_
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
_
,
_
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
)
outputs
,
mems_1
=
model
(
input_ids_1
)
outputs
,
mems_1
=
model
(
input_ids_1
)
...
@@ -159,7 +150,8 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -159,7 +150,8 @@ class XLNetModelTest(unittest.TestCase):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
XLNetLMHeadModel
(
config
)
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
model
.
eval
()
...
@@ -198,7 +190,82 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -198,7 +190,82 @@ class XLNetModelTest(unittest.TestCase):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_commons
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
):
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
XLNetForQuestionAnswering
(
config
)
model
.
eval
()
outputs
=
model
(
input_ids_1
)
start_top_log_probs
,
start_top_index
,
end_top_log_probs
,
end_top_index
,
cls_logits
,
mems
=
outputs
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
,
p_mask
=
input_mask
)
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
,
cls_index
=
sequence_labels
,
is_impossible
=
is_impossible_labels
)
total_loss
,
start_logits
,
end_logits
,
cls_logits
,
mems
=
outputs
outputs
=
model
(
input_ids_1
,
start_positions
=
sequence_labels
,
end_positions
=
sequence_labels
)
total_loss
,
start_logits
,
end_logits
,
mems
=
outputs
result
=
{
"loss"
:
total_loss
,
"start_logits"
:
start_logits
,
"end_logits"
:
end_logits
,
"cls_logits"
:
cls_logits
,
"mems"
:
mems
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"start_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"end_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"cls_logits"
].
size
()),
[
self
.
batch_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
XLNetForSequenceClassification
(
config
)
model
.
eval
()
logits
,
mems_1
=
model
(
input_ids_1
)
loss
,
logits
,
mems_1
=
model
(
input_ids_1
,
labels
=
sequence_labels
)
result
=
{
"loss"
:
loss
,
"mems_1"
:
mems_1
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_commons
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
inp_q
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
create_and_check_commons
(
self
,
config
,
inputs_dict
,
test_pruning
=
False
)
create_and_check_commons
(
self
,
config
,
inputs_dict
,
test_pruning
=
False
)
...
@@ -224,27 +291,19 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -224,27 +291,19 @@ class XLNetModelTest(unittest.TestCase):
tester
.
set_seed
()
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
tester
.
set_seed
()
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_commons
(
*
config_and_inputs
)
tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
@
classmethod
def
mask_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a tensor with padding on the right (0.0 for )."""
if
rng
is
None
:
rng
=
random
.
Random
()
total_dims
=
1
for
dim
in
shape
:
total_dims
*=
dim
values
=
[]
tester
.
set_seed
()
for
_
in
range
(
total_dims
):
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
values
.
append
(
rng
.
r
andin
t
(
0
,
vocab_size
-
1
)
)
tester
.
create_and_check_xlnet_qa
(
*
config_
and
_
in
puts
)
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
long
).
view
(
shape
).
contiguous
()
tester
.
set_seed
()
config_and_inputs
=
tester
.
prepare_config_and_inputs
()
tester
.
create_and_check_xlnet_commons
(
*
config_and_inputs
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment