chenpangpang / transformers / Commits / 15b70338

Commit 15b70338, authored Jul 04, 2019 by thomwolf

adding squad model to xlnet and xlm

parent fbe04423

Showing 6 changed files with 531 additions and 326 deletions (+531 -326):
  pytorch_pretrained_bert/model_utils.py                  +182  -15
  pytorch_pretrained_bert/modeling_xlm.py                  +86  -149
  pytorch_pretrained_bert/modeling_xlnet.py                +68  -26
  pytorch_pretrained_bert/tests/modeling_openai_test.py     +0  -4
  pytorch_pretrained_bert/tests/modeling_xlm_test.py       +94  -90
  pytorch_pretrained_bert/tests/modeling_xlnet_test.py    +101  -42
pytorch_pretrained_bert/model_utils.py
...
@@ -25,7 +25,7 @@ from io import open

 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss, MSELoss, functional as F

 from .file_utils import cached_path
...
@@ -301,22 +301,189 @@ class Conv1D(nn.Module):
         return x


+class PoolerStartLogits(nn.Module):
+    """ Compute SQuAD start_logits from sequence hidden states. """
+    def __init__(self, config):
+        super(PoolerStartLogits, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, 1)
+
+    def forward(self, hidden_states, p_mask=None):
+        """ Args:
+            `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
+                shape [batch_size, seq_len]. 1.0 means token should be masked.
+        """
+        x = self.dense(hidden_states).squeeze(-1)
+
+        if p_mask is not None:
+            x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
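Note on the masking convention: `p_mask` is applied additively rather than by boolean indexing, so masked positions end up with a huge negative logit and essentially zero probability after a softmax. A minimal standalone sketch of the same arithmetic (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0, 3.0]])   # toy start logits for one example
p_mask = torch.tensor([[1.0, 0.0, 0.0, 1.0]])   # 1.0 = question/special token, must not be predicted

masked = logits * (1 - p_mask) - 1e30 * p_mask  # same expression as in PoolerStartLogits.forward
probs = F.softmax(masked, dim=-1)
print(probs)                                    # masked positions come out as (numerically) zero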
+
+
+class PoolerEndLogits(nn.Module):
+    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
+    """
+    def __init__(self, config):
+        super(PoolerEndLogits, self).__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dense_1 = nn.Linear(config.hidden_size, 1)
+
+    def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
+        """ Args:
+            One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
+            `start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
+            `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
+            `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
+                shape [batch_size, seq_len]. 1.0 means token should be masked.
+        """
+        slen, hsz = hidden_states.shape[-2:]
+        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
+        if start_positions is not None:
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
+            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
+
+        x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
+        x = self.activation(x)
+        x = self.LayerNorm(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        if p_mask is not None:
+            x = x * (1 - p_mask) - 1e30 * p_mask
+
+        return x
+
+
+class PoolerAnswerClass(nn.Module):
+    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
+    def __init__(self, config):
+        super(PoolerAnswerClass, self).__init__()
+        self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
+        self.activation = nn.Tanh()
+        self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
+
+    def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
+        """ Args:
+            One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
+            `start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
+            `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
+            `cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
+
+            # note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
+        """
+        slen, hsz = hidden_states.shape[-2:]
+        assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
+        if start_positions is not None:
+            start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            start_states = hidden_states.gather(-2, start_positions).squeeze(-2)  # shape (bsz, hsz)
+
+        if cls_index is not None:
+            cls_index = cls_index[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
+            cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, hsz)
+        else:
+            cls_token_state = hidden_states[:, -1, :]  # shape (bsz, hsz)
+
+        x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
+        x = self.activation(x)
+        x = self.dense_1(x).squeeze(-1)
+
+        return x
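Both `PoolerEndLogits` and `PoolerAnswerClass` pick out the start-token representation with the same `expand` + `gather` pattern. A small sketch with made-up shapes showing what that indexing does:

import torch

bsz, slen, hsz = 2, 5, 4
hidden_states = torch.randn(bsz, slen, hsz)
start_positions = torch.tensor([3, 1])                       # one gold start index per example

index = start_positions[:, None, None].expand(-1, -1, hsz)   # (bsz, 1, hsz)
start_states = hidden_states.gather(-2, index).squeeze(-2)   # (bsz, hsz)

assert torch.equal(start_states[0], hidden_states[0, 3])
assert torch.equal(start_states[1], hidden_states[1, 1])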
+
+
+class SQuADHead(nn.Module):
+    """ A SQuAD head inspired by XLNet.
+    """
+    def __init__(self, config):
+        super(SQuADHead, self).__init__()
+        self.start_n_top = config.start_n_top
+        self.end_n_top = config.end_n_top
+
+        self.start_logits = PoolerStartLogits(config)
+        self.end_logits = PoolerEndLogits(config)
+        self.answer_class = PoolerAnswerClass(config)
+
+    def forward(self, hidden_states, start_positions=None, end_positions=None,
+                cls_index=None, is_impossible=None, p_mask=None):
+        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
+        """
+        outputs = ()
+
+        start_logits = self.start_logits(hidden_states, p_mask)
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, let's remove the dimension added by batch splitting
+            for x in (start_positions, end_positions, cls_index, is_impossible):
+                if x is not None and x.dim() > 1:
+                    x.squeeze_(-1)
+
+            # during training, compute the end logits based on the ground truth of the start position
+            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+            if cls_index is not None and is_impossible is not None:
+                # Predict answerability from the representation of CLS and START
+                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+                loss_fct_cls = nn.BCEWithLogitsLoss()
+                cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
+                total_loss += cls_loss * 0.5
+
+                outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
+            else:
+                outputs = (total_loss, start_logits, end_logits) + outputs
+
+        else:
+            # during inference, compute the end logits based on beam search
+            bsz, slen, hsz = hidden_states.size()
+            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)
+
+            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
+            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
+
+            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
+            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
+
+            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
+            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
+
+        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
+        # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
+        return outputs
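A rough usage sketch of the new head. The config object below is a stand-in carrying only the attributes the head actually reads (`hidden_size`, `layer_norm_eps`, `start_n_top`, `end_n_top`); the import path matches this commit's file layout and changed in later releases:

from types import SimpleNamespace
import torch
from pytorch_pretrained_bert.model_utils import SQuADHead   # path as of this commit

config = SimpleNamespace(hidden_size=16, layer_norm_eps=1e-12, start_n_top=5, end_n_top=5)
head = SQuADHead(config)

bsz, slen = 2, 12
hidden_states = torch.randn(bsz, slen, config.hidden_size)

# Training: pass the gold span to get a loss (plus cls_index/is_impossible for SQuAD 2.0).
start_positions = torch.tensor([3, 5])
end_positions = torch.tensor([4, 7])
total_loss, start_logits, end_logits = head(hidden_states,
                                            start_positions=start_positions,
                                            end_positions=end_positions)

# Inference: no labels, so the head returns beam-search style top-k candidates instead.
(start_top_log_probs, start_top_index,
 end_top_log_probs, end_top_index, cls_logits) = head(hidden_states)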


 class SequenceSummary(nn.Module):
     """ Compute a single vector summary of a sequence hidden states according to various possibilities:
         Args of the config class:
             summary_type:
                 - 'last' => [default] take the last token hidden state (like XLNet)
                 - 'first' => take the first token hidden state (like Bert)
                 - 'mean' => take the mean of all tokens hidden states
                 - 'token_ids' => supply a Tensor of classification token indices (GPT/GPT-2)
                 - 'attn' => Not implemented now, use multi-head attention
             summary_use_proj: Add a projection after the vector extraction
             summary_num_classes: If > 0: the projection outputs to n classes (otherwise to hidden_size)
             summary_activation:
                 'tanh' => add a tanh activation to the output
                 None => no activation
     """
     def __init__(self, config):
         super(SequenceSummary, self).__init__()

         self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
...
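`SequenceSummary` is truncated in this view, but the three simple summary modes its docstring lists boil down to one-line tensor ops; a sketch for reference (not the class implementation itself):

import torch

hidden_states = torch.randn(2, 7, 16)   # (bsz, seq_len, hidden_size)

last = hidden_states[:, -1]             # 'last'  -> XLNet-style summary
first = hidden_states[:, 0]             # 'first' -> BERT-style [CLS] summary
mean = hidden_states.mean(dim=1)        # 'mean'  -> average over all tokens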
pytorch_pretrained_bert/modeling_xlm.py
...
@@ -35,7 +35,8 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
+from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
+                          prune_linear_layer, SequenceSummary, SQuADHead)

 logger = logging.getLogger(__name__)
...
@@ -67,15 +68,23 @@ class XLMConfig(PretrainedConfig):
                  n_langs=1,
                  max_position_embeddings=512,
                  embed_init_std=2048 ** -0.5,
+                 layer_norm_eps=1e-12,
                  init_std=0.02,
-                 summary_type="last",
-                 use_proj=True,
                  bos_index=0,
                  eos_index=1,
                  pad_index=2,
                  unk_index=3,
                  mask_index=5,
                  is_encoder=True,
+                 finetuning_task=None,
+                 num_labels=2,
+                 summary_type='last',
+                 summary_use_proj=True,
+                 summary_activation='tanh',
+                 summary_dropout=0.1,
+                 start_n_top=5,
+                 end_n_top=5,
                  **kwargs):
        """Constructs XLMConfig.
...
@@ -140,8 +149,7 @@ class XLMConfig(PretrainedConfig):
            self.causal = causal
            self.asm = asm
            self.n_langs = n_langs
-           self.summary_type = summary_type
-           self.use_proj = use_proj
+           self.layer_norm_eps = layer_norm_eps
            self.bos_index = bos_index
            self.eos_index = eos_index
            self.pad_index = pad_index
...
@@ -151,6 +159,14 @@ class XLMConfig(PretrainedConfig):
            self.max_position_embeddings = max_position_embeddings
            self.embed_init_std = embed_init_std
            self.init_std = init_std
+           self.finetuning_task = finetuning_task
+           self.num_labels = num_labels
+           self.summary_type = summary_type
+           self.summary_use_proj = summary_use_proj
+           self.summary_activation = summary_activation
+           self.summary_dropout = summary_dropout
+           self.start_n_top = start_n_top
+           self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
...
@@ -172,26 +188,6 @@ class XLMConfig(PretrainedConfig):
         return self.n_layers


-def Embedding(num_embeddings, embedding_dim, padding_idx=None, config=None):
-    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    if config is not None and config.embed_init_std is not None:
-        nn.init.normal_(m.weight, mean=0, std=config.embed_init_std)
-    if padding_idx is not None:
-        nn.init.constant_(m.weight[padding_idx], 0)
-    return m
-
-
-def Linear(in_features, out_features, bias=True, config=None):
-    m = nn.Linear(in_features, out_features, bias)
-    if config is not None and config.init_std is not None:
-        nn.init.normal_(m.weight, mean=0, std=config.init_std)
-        if bias:
-            nn.init.constant_(m.bias, 0.)
-    # nn.init.xavier_uniform_(m.weight)
-    # nn.init.constant_(m.bias, 0.)
-    return m
-
-
 def create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([
         [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
...
@@ -244,7 +240,7 @@ class MultiHeadAttention(nn.Module):
     NEW_ID = itertools.count()

     def __init__(self, n_heads, dim, config):
-        super().__init__()
+        super(MultiHeadAttention, self).__init__()
         self.layer_id = next(MultiHeadAttention.NEW_ID)
         self.output_attentions = config.output_attentions
         self.dim = dim
...
@@ -252,10 +248,10 @@ class MultiHeadAttention(nn.Module):
         self.dropout = config.attention_dropout
         assert self.dim % self.n_heads == 0

-        self.q_lin = Linear(dim, dim, config=config)
-        self.k_lin = Linear(dim, dim, config=config)
-        self.v_lin = Linear(dim, dim, config=config)
-        self.out_lin = Linear(dim, dim, config=config)
+        self.q_lin = nn.Linear(dim, dim)
+        self.k_lin = nn.Linear(dim, dim)
+        self.v_lin = nn.Linear(dim, dim)
+        self.out_lin = nn.Linear(dim, dim)

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
...
@@ -342,10 +338,10 @@ class MultiHeadAttention(nn.Module):

 class TransformerFFN(nn.Module):

     def __init__(self, in_dim, dim_hidden, out_dim, config):
-        super().__init__()
+        super(TransformerFFN, self).__init__()
         self.dropout = config.dropout
-        self.lin1 = Linear(in_dim, dim_hidden, config=config)
-        self.lin2 = Linear(dim_hidden, out_dim, config=config)
+        self.lin1 = nn.Linear(in_dim, dim_hidden)
+        self.lin2 = nn.Linear(dim_hidden, out_dim)
         self.act = gelu if config.gelu_activation else F.relu

     def forward(self, input):
...
@@ -363,17 +359,21 @@ class XLMPreTrainedModel(PreTrainedModel):
     config_class = XLMConfig
     pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = None
-    base_model_prefix = "xlm"
+    base_model_prefix = "transformer"

     def __init__(self, *inputs, **kwargs):
         super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

     def init_weights(self, module):
-        """ Initialize the weights.
-        """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Weights are initialized in module instantiation (see above)
-            pass
+        """ Initialize the weights. """
+        if isinstance(module, nn.Embedding):
+            if self.config is not None and self.config.embed_init_std is not None:
+                nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
+        if isinstance(module, nn.Linear):
+            if self.config is not None and self.config.init_std is not None:
+                nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
+            if hasattr(module, 'bias') and module.bias is not None:
+                nn.init.constant_(module.bias, 0.)
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
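With the custom `Embedding`/`Linear` helpers removed, initialisation now happens once after construction via `self.apply(self.init_weights)` in `XLMModel.__init__` (added below) instead of inside each helper. A hypothetical standalone version of the same pattern:

import torch.nn as nn

def init_weights(module, embed_init_std=2048 ** -0.5, init_std=0.02):
    # Standalone rewrite of the rules above, for illustration only.
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0, std=embed_init_std)
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0, std=init_std)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.)

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8))
model.apply(init_weights)   # apply() visits every submodule, mirroring self.apply(self.init_weights)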
...
@@ -471,13 +471,13 @@ class XLMModel(XLMPreTrainedModel):
         assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

         # embeddings
-        self.position_embeddings = Embedding(config.max_position_embeddings, self.dim, config=config)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
         if config.sinusoidal_embeddings:
             create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
         if config.n_langs > 1:
-            self.lang_embeddings = Embedding(self.n_langs, self.dim, config=config)
-        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index, config=config)
-        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)
+            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
+        self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
+        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)

         # transformer layers
         self.attentions = nn.ModuleList()
...
@@ -490,12 +490,14 @@ class XLMModel(XLMPreTrainedModel):
         for _ in range(self.n_layers):
             self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config))
-            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
+            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
             # if self.is_decoder:
-            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
+            #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
             #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
-            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
+            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
+
+        self.apply(self.init_weights)

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
...
@@ -636,14 +638,14 @@ class XLMPredLayer(nn.Module):
     Prediction layer (cross_entropy or adaptive_softmax).
     """
     def __init__(self, config):
-        super().__init__()
+        super(XLMPredLayer, self).__init__()
         self.asm = config.asm
         self.n_words = config.n_words
         self.pad_index = config.pad_index
         dim = config.emb_dim

         if config.asm is False:
-            self.proj = Linear(dim, config.n_words, bias=True)
+            self.proj = nn.Linear(dim, config.n_words, bias=True)
         else:
             self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                 in_features=dim,
...
@@ -653,28 +655,24 @@ class XLMPredLayer(nn.Module):
                 head_bias=True,  # default is False
             )

-    def forward(self, x, y, get_scores=False):
-        """ Compute the loss, and optionally the scores.
-        """
-        assert (y == self.pad_index).sum().item() == 0
-
+    def forward(self, x, y=None):
+        """
+            Compute the loss, and optionally the scores.
+        """
+        outputs = ()
         if self.asm is False:
             scores = self.proj(x).view(-1, self.n_words)
-            loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+            outputs = (scores,) + outputs
+            if y is not None:
+                loss = F.cross_entropy(scores, y, reduction='elementwise_mean')
+                outputs = (loss,) + outputs
         else:
-            _, loss = self.proj(x, y)
-            scores = self.proj.log_prob(x) if get_scores else None
-
-        return scores, loss
-
-    def get_scores(self, x):
-        """
-        Compute scores.
-        """
-        assert x.dim() == 2
-        return self.proj.log_prob(x) if self.asm else self.proj(x)
+            scores = self.proj.log_prob(x)
+            outputs = (scores,) + outputs
+            if y is not None:
+                _, loss = self.proj(x, y)
+                outputs = (loss,) + outputs
+
+        return outputs


 class XLMWithLMHeadModel(XLMPreTrainedModel):
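`XLMPredLayer` keeps two prediction paths: a plain linear projection scored with `F.cross_entropy`, and PyTorch's adaptive softmax when `config.asm` is set. A standalone sketch of the `nn.AdaptiveLogSoftmaxWithLoss` API it relies on (toy sizes, not the XLM vocabulary):

import torch
import torch.nn as nn

asm = nn.AdaptiveLogSoftmaxWithLoss(in_features=16, n_classes=100, cutoffs=[10, 50])

x = torch.randn(8, 16)
y = torch.randint(0, 100, (8,))

out, loss = asm(x, y)        # forward returns (output, loss), matching `_, loss = self.proj(x, y)`
log_probs = asm.log_prob(x)  # full (8, 100) log-probabilities, matching `self.proj.log_prob(x)`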
...
@@ -731,6 +729,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     """
     def __init__(self, config):
         super(XLMWithLMHeadModel, self).__init__(config)
+        self.torchscript = config.torchscript
         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)
...
@@ -741,7 +740,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.pred_layer.proj.weight = self.transformer.embeddings.weight
+        if self.torchscript:
+            self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
+        else:
+            self.pred_layer.proj.weight = self.transformer.embeddings.weight

     def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, labels=None, head_mask=None):
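The new `torchscript` branch in `tie_weights` exists because TorchScript cannot trace two modules that share one `Parameter`, so the traced variant gets an independent copy while the normal path keeps genuine weight tying. A minimal sketch of the two options on toy modules (not the XLM embedding itself):

import torch.nn as nn

emb = nn.Embedding(10, 4)
proj = nn.Linear(4, 10, bias=False)

# Normal path: both modules point at the very same Parameter (true tying).
proj.weight = emb.weight

# TorchScript path: an independent Parameter with the same initial values.
proj.weight = nn.Parameter(emb.weight.clone())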
...
@@ -775,55 +777,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
                                            langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

         output = transformer_outputs[0]
-        logits = self.pred_layer(output, labels)
-        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
-
-        if labels is not None:
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
-            outputs = [loss] + outputs
-
-        outputs = [logits] + outputs
+        outputs = self.pred_layer(output, labels)
+        outputs = outputs + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

         return outputs


-class XLMSequenceSummary(nn.Module):
-    def __init__(self, config):
-        super(XLMSequenceSummary, self).__init__()
-        self.summary_type = config.summary_type
-        if config.use_proj:
-            self.summary = nn.Linear(config.d_model, config.d_model)
-        else:
-            self.summary = None
-        if config.summary_type == 'attn':
-            # We should use a standard multi-head attention module with absolute positional embedding for that.
-            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
-            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
-            raise NotImplementedError
-        self.dropout = nn.Dropout(config.dropout)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states):
-        """ hidden_states: float Tensor in shape [bsz, seq_len, d_model], the hidden-states of the last layer."""
-        if self.summary_type == 'last':
-            output = hidden_states[:, -1]
-        elif self.summary_type == 'first':
-            output = hidden_states[:, 0]
-        elif self.summary_type == 'mean':
-            output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
-            raise NotImplementedError
-
-        output = self.summary(output)
-        output = self.activation(output)
-        output = self.dropout(output)
-
-        return output
-
-
 class XLMForSequenceClassification(XLMPreTrainedModel):
     """XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
...
@@ -890,15 +849,15 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
     """
     def __init__(self, config):
         super(XLMForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels

         self.transformer = XLMModel(config)
-        self.sequence_summary = XLMSequenceSummary(config)
+        self.sequence_summary = SequenceSummary(config)
+        self.logits_proj = nn.Linear(config.d_model, config.num_labels)

         self.apply(self.init_weights)

-    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None,
-                cache=None, labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+                attention_mask=None, cache=None, labels=None, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
...
@@ -930,10 +889,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
                                            langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

         output = transformer_outputs[0]
-        output = self.sequence_summary(output)
+        logits = self.sequence_summary(output)
+        logits = self.logits_proj(output)

-        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
+        outputs = (logits,) + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

         if labels is not None:
             if self.num_labels == 1:
...
@@ -943,9 +901,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

-        outputs = [logits] + outputs
         return outputs
...
@@ -1010,41 +966,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         super(XLMForQuestionAnswering, self).__init__(config)

         self.transformer = XLMModel(config)
-        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+        self.qa_outputs = SQuADHead(config)

         self.apply(self.init_weights)

-    def forward(self, input_ids, lengths=None, positions=None, langs=None, attention_mask=None, cache=None,
-                labels=None, head_mask=None):
+    def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
+                attention_mask=None, cache=None, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
         transformer_outputs = self.transformer(input_ids, lengths=lengths, positions=positions, token_type_ids=token_type_ids,
                                                langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)

         output = transformer_outputs[0]
-        logits = self.qa_outputs(output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-        outputs = transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
-
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions.clamp_(0, ignored_index)
-            end_positions.clamp_(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
-
-        outputs = [start_logits, end_logits] + outputs
+        outputs = self.qa_outputs(output, start_positions=start_positions, end_positions=end_positions,
+                                  cls_index=cls_index, is_impossible=is_impossible, p_mask=p_mask)
+
+        outputs = outputs + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

         return outputs
pytorch_pretrained_bert/modeling_xlnet.py
...
@@ -32,8 +32,8 @@ from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path
 from .model_utils import (CONFIG_NAME, WEIGHTS_NAME,
                           PretrainedConfig, PreTrainedModel,
-                          SequenceSummary)
+                          SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)

 logger = logging.getLogger(__name__)
...
@@ -221,13 +221,15 @@ class XLNetConfig(PretrainedConfig):
                  bi_data=False,
                  clamp_len=-1,
                  same_length=False,
                  finetuning_task=None,
                  num_labels=2,
                  summary_type='last',
                  summary_use_proj=True,
                  summary_activation='tanh',
                  summary_dropout=0.1,
+                 start_n_top=5,
+                 end_n_top=5,
                  **kwargs):
        """Constructs XLNetConfig.
...
@@ -313,6 +315,8 @@ class XLNetConfig(PretrainedConfig):
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_dropout = summary_dropout
+           self.start_n_top = start_n_top
+           self.end_n_top = end_n_top
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
...
@@ -1114,6 +1118,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     """
     def __init__(self, config):
         super(XLNetForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
         self.transformer = XLNetModel(config)
         self.sequence_summary = SequenceSummary(config)
...
@@ -1174,7 +1179,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
 class XLNetForQuestionAnswering(XLNetPreTrainedModel):
-    """XLNet model for Question Answering (span extraction).
+    """
+    XLNet model for Question Answering (span extraction).
     This module is composed of the XLNet model with a linear layer on top of
     the sequence output that computes start_logits and end_logits
...
@@ -1231,41 +1236,78 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     """
     def __init__(self, config):
         super(XLNetForQuestionAnswering, self).__init__(config)
+        self.start_n_top = config.start_n_top
+        self.end_n_top = config.end_n_top

         self.transformer = XLNetModel(config)
-        self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
+        self.start_logits = PoolerStartLogits(config)
+        self.end_logits = PoolerEndLogits(config)
+        self.answer_class = PoolerAnswerClass(config)

         self.apply(self.init_weights)

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                start_positions=None, end_positions=None, head_mask=None):
+                start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
         transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
                                                mems, perm_mask, target_mapping, inp_q, head_mask)
+        hidden_states = transformer_outputs[0]
+        start_logits = self.start_logits(hidden_states, p_mask)

-        logits = self.qa_outputs(transformer_outputs[0])
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions.clamp_(0, ignored_index)
-            end_positions.clamp_(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            # If we are on multi-GPU, let's remove the dimension added by batch splitting
+            for x in (start_positions, end_positions, cls_index, is_impossible):
+                if x is not None and x.dim() > 1:
+                    x.squeeze_(-1)
+
+            # during training, compute the end logits based on the ground truth of the start position
+            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
+
+            loss_fct = CrossEntropyLoss()
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = (total_loss,) + outputs

-        return outputs
-        # return (loss), logits, (mems), (hidden states), (attentions)
+            if cls_index is not None and is_impossible is not None:
+                # Predict answerability from the representation of CLS and START
+                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
+                loss_fct_cls = nn.BCEWithLogitsLoss()
+                cls_loss = loss_fct_cls(cls_logits, is_impossible)
+
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
+                # comparable to start_loss and end_loss
+                total_loss += cls_loss * 0.5
+
+                outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
+            else:
+                outputs = (total_loss, start_logits, end_logits) + outputs
+
+        else:
+            # during inference, compute the end logits based on beam search
+            bsz, slen, hsz = hidden_states.size()
+            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)
+
+            start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
+            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
+
+            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
+            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
+            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
+            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
+
+            end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
+            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
+            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
+
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+
+            outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
+
+        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
+        # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
+        return outputs
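The inference branch is mostly tensor bookkeeping. A compact sketch (toy shapes) of the three operations it chains: `topk` over the start distribution, `gather` of the matching hidden states, and the `einsum` that builds the expected start representation fed to the answer classifier:

import torch
import torch.nn.functional as F

bsz, slen, hsz, k = 2, 6, 4, 3
hidden_states = torch.randn(bsz, slen, hsz)
start_logits = torch.randn(bsz, slen)

start_probs = F.softmax(start_logits, dim=-1)                            # (bsz, slen)
top_probs, top_index = torch.topk(start_probs, k, dim=-1)                # (bsz, k)

index = top_index.unsqueeze(-1).expand(-1, -1, hsz)                      # (bsz, k, hsz)
start_states = torch.gather(hidden_states, -2, index)                    # (bsz, k, hsz): the k best start vectors

expected_start = torch.einsum("blh,bl->bh", hidden_states, start_probs)  # (bsz, hsz): probability-weighted average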
pytorch_pretrained_bert/tests/modeling_openai_test.py
...
@@ -16,11 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import unittest
-import json
-import random
-import shutil
 import pytest

 import torch
...
pytorch_pretrained_bert/tests/modeling_xlm_test.py
...
@@ -20,7 +20,7 @@ import unittest
 import shutil
 import pytest

-from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMForQuestionAnswering, XLMForSequenceClassification)
+from pytorch_pretrained_bert import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
 from pytorch_pretrained_bert.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP

 from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
...
@@ -58,7 +58,8 @@ class XLMModelTest(unittest.TestCase):
                      summary_type="last",
                      use_proj=True,
                      scope=None,
-                     all_model_classes=(XLMModel,),  # , XLMForSequenceClassification, XLMForTokenClassification),
+                     all_model_classes=(XLMModel, XLMWithLMHeadModel,
+                                        XLMForQuestionAnswering, XLMForSequenceClassification),  # , XLMForSequenceClassification, XLMForTokenClassification),
                     ):
            self.parent = parent
            self.batch_size = batch_size
...
@@ -93,6 +94,7 @@ class XLMModelTest(unittest.TestCase):
        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+           input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_lengths = None
            if self.use_input_lengths:
...
@@ -104,11 +106,11 @@ class XLMModelTest(unittest.TestCase):
            sequence_labels = None
            token_labels = None
-           choice_labels = None
+           is_impossible_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-               choice_labels = ids_tensor([self.batch_size], self.num_choices)
+               is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLMConfig(
                vocab_size_or_config_json_file=self.vocab_size,
...
@@ -128,14 +130,14 @@ class XLMModelTest(unittest.TestCase):
                summary_type=self.summary_type,
                use_proj=self.use_proj)

-           return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels
+           return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask

        def check_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

-       def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+       def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMModel(config=config)
            model.eval()
            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
...
@@ -150,90 +152,92 @@ class XLMModelTest(unittest.TestCase):
                [self.batch_size, self.seq_length, self.hidden_size])

-       # def create_and_check_xlm_for_masked_lm(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-       #     model = XLMForMaskedLM(config=config)
-       #     model.eval()
-       #     loss, prediction_scores = model(input_ids, token_type_ids, input_lengths, token_labels)
-       #     result = {
-       #         "loss": loss,
-       #         "prediction_scores": prediction_scores,
-       #     }
-       #     self.parent.assertListEqual(
-       #         list(result["prediction_scores"].size()),
-       #         [self.batch_size, self.seq_length, self.vocab_size])
-       #     self.check_loss_output(result)
-
-       # def create_and_check_xlm_for_question_answering(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-       #     model = XLMForQuestionAnswering(config=config)
-       #     model.eval()
-       #     loss, start_logits, end_logits = model(input_ids, token_type_ids, input_lengths, sequence_labels, sequence_labels)
-       #     result = {
-       #         "loss": loss,
-       #         "start_logits": start_logits,
-       #         "end_logits": end_logits,
-       #     }
-       #     self.parent.assertListEqual(
-       #         list(result["start_logits"].size()),
-       #         [self.batch_size, self.seq_length])
-       #     self.parent.assertListEqual(
-       #         list(result["end_logits"].size()),
-       #         [self.batch_size, self.seq_length])
-       #     self.check_loss_output(result)
-
-       # def create_and_check_xlm_for_sequence_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-       #     config.num_labels = self.num_labels
-       #     model = XLMForSequenceClassification(config)
-       #     model.eval()
-       #     loss, logits = model(input_ids, token_type_ids, input_lengths, sequence_labels)
-       #     result = {
-       #         "loss": loss,
-       #         "logits": logits,
-       #     }
-       #     self.parent.assertListEqual(
-       #         list(result["logits"].size()),
-       #         [self.batch_size, self.num_labels])
-       #     self.check_loss_output(result)
-
-       # def create_and_check_xlm_for_token_classification(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-       #     config.num_labels = self.num_labels
-       #     model = XLMForTokenClassification(config=config)
-       #     model.eval()
-       #     loss, logits = model(input_ids, token_type_ids, input_lengths, token_labels)
-       #     result = {
-       #         "loss": loss,
-       #         "logits": logits,
-       #     }
-       #     self.parent.assertListEqual(
-       #         list(result["logits"].size()),
-       #         [self.batch_size, self.seq_length, self.num_labels])
-       #     self.check_loss_output(result)
-
-       # def create_and_check_xlm_for_multiple_choice(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
-       #     config.num_choices = self.num_choices
-       #     model = XLMForMultipleChoice(config=config)
-       #     model.eval()
-       #     multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-       #     multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-       #     multiple_choice_input_lengths = input_lengths.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
-       #     loss, logits = model(multiple_choice_inputs_ids,
-       #                          multiple_choice_token_type_ids,
-       #                          multiple_choice_input_lengths,
-       #                          choice_labels)
-       #     result = {
-       #         "loss": loss,
-       #         "logits": logits,
-       #     }
-       #     self.parent.assertListEqual(
-       #         list(result["logits"].size()),
-       #         [self.batch_size, self.num_choices])
-       #     self.check_loss_output(result)
+       def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
+           model = XLMWithLMHeadModel(config)
+           model.eval()
+
+           loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
+
+           result = {
+               "loss": loss,
+               "logits": logits,
+           }
+
+           self.parent.assertListEqual(
+               list(result["loss"].size()),
+               [])
+           self.parent.assertListEqual(
+               list(result["logits"].size()),
+               [self.batch_size, self.seq_length, self.vocab_size])
+
+       def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
+           model = XLMForQuestionAnswering(config)
+           model.eval()
+
+           outputs = model(input_ids)
+           start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
+
+           outputs = model(input_ids, start_positions=sequence_labels,
+                           end_positions=sequence_labels,
+                           cls_index=sequence_labels,
+                           is_impossible=is_impossible_labels,
+                           p_mask=input_mask)
+
+           outputs = model(input_ids, start_positions=sequence_labels,
+                           end_positions=sequence_labels,
+                           cls_index=sequence_labels,
+                           is_impossible=is_impossible_labels)
+
+           total_loss, start_logits, end_logits, cls_logits = outputs
+
+           outputs = model(input_ids, start_positions=sequence_labels,
+                           end_positions=sequence_labels)
+
+           total_loss, start_logits, end_logits = outputs
+
+           result = {
+               "loss": total_loss,
+               "start_logits": start_logits,
+               "end_logits": end_logits,
+               "cls_logits": cls_logits,
+           }
+
+           self.parent.assertListEqual(
+               list(result["loss"].size()),
+               [])
+           self.parent.assertListEqual(
+               list(result["start_logits"].size()),
+               [self.batch_size, self.seq_length])
+           self.parent.assertListEqual(
+               list(result["end_logits"].size()),
+               [self.batch_size, self.seq_length])
+           self.parent.assertListEqual(
+               list(result["cls_logits"].size()),
+               [self.batch_size])
+
+       def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
+           model = XLMForSequenceClassification(config)
+           model.eval()
+
+           (logits,) = model(input_ids)
+           loss, logits = model(input_ids, labels=sequence_labels)
+
+           result = {
+               "loss": loss,
+               "logits": logits,
+           }
+
+           self.parent.assertListEqual(
+               list(result["loss"].size()),
+               [])
+           self.parent.assertListEqual(
+               list(result["logits"].size()),
+               [self.batch_size, self.type_sequence_label_size])

-       def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, choice_labels):
+       def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
            create_and_check_commons(self, config, inputs_dict)
...
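The new checks follow the repository's pattern of asserting on `size()` lists rather than on values. A tiny self-contained illustration of that pattern (hypothetical test, not from the commit):

import unittest
import torch

class ShapeCheckExample(unittest.TestCase):
    def test_loss_and_logits_shapes(self):
        batch_size, seq_length = 2, 7
        total_loss = torch.tensor(0.5)                      # a scalar loss has an empty size list
        start_logits = torch.randn(batch_size, seq_length)

        self.assertListEqual(list(total_loss.size()), [])
        self.assertListEqual(list(start_logits.size()), [batch_size, seq_length])

if __name__ == "__main__":
    unittest.main()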
pytorch_pretrained_bert/tests/modeling_xlnet_test.py
...
@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -49,6 +49,7 @@ class XLNetModelTest(unittest.TestCase):
d_inner
=
128
,
d_inner
=
128
,
num_hidden_layers
=
5
,
num_hidden_layers
=
5
,
max_position_embeddings
=
10
,
max_position_embeddings
=
10
,
type_sequence_label_size
=
2
,
untie_r
=
True
,
untie_r
=
True
,
bi_data
=
False
,
bi_data
=
False
,
same_length
=
False
,
same_length
=
False
,
...
@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -80,12 +81,14 @@ class XLNetModelTest(unittest.TestCase):
        self.initializer_range = initializer_range
        self.seed = seed
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.all_model_classes = all_model_classes

    def prepare_config_and_inputs(self):
        input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
        input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

        input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
        perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
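For readers outside the test suite: `ids_tensor` is a small helper from the shared test utilities. A minimal sketch of what such a helper does, closely mirroring the `mask_tensor` classmethod shown further down in this diff (an assumption for illustration, not necessarily the exact shared implementation):

import random

import torch

def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Create a random torch.long tensor of `shape` with values drawn from [0, vocab_size)."""
    if rng is None:
        rng = random.Random()

    # Flatten the requested shape, fill with random token ids, then reshape.
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()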
...
@@ -94,30 +97,13 @@ class XLNetModelTest(unittest.TestCase):
        target_mapping[:, 0, -1] = 1.0  # predict last token
        inp_q = target_mapping[:, 0, :].clone()  # predict last token

        # inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
        # token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
        # input_mask: float32 Tensor in shape [bsz, len], the input mask.
        #   0 for real tokens and 1 for padding.
        # mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
        #   from previous batches. The length of the list equals num_hidden_layers.
        #   If None, no memory is used.
        # perm_mask: float32 Tensor in shape [bsz, len, len].
        #   If perm_mask[k, i, j] = 0, i attend to j in batch k;
        #   if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
        #   If None, each position attends to all the others.
        # target_mapping: float32 Tensor in shape [bsz, num_predict, len].
        #   If target_mapping[k, i, j] = 1, the i-th predict in batch k is
        #   on the j-th token.
        #   Only used during pretraining for partial prediction.
        #   Set to None during finetuning.
        # inp_q: float32 Tensor in shape [bsz, len].
        #   1 for tokens with losses and 0 for tokens without losses.
        #   Only used during pretraining for two-stream attention.
        #   Set to None during finetuning.

        sequence_labels = None
        lm_labels = None
        is_impossible_labels = None
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            is_impossible_labels = ids_tensor([self.batch_size], 2).float()
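The comment block retained above documents the permutation-LM inputs. A short sketch of how the tester's "predict last token" setup follows from those definitions; batch size and sequence length are placeholder values:

import torch

batch_size, seq_length = 2, 7

# perm_mask[k, i, j] = 1 means position i may not attend to position j in example k.
# Masking the last column hides the final token from every position, so it has to be predicted.
perm_mask = torch.zeros(batch_size, seq_length + 1, seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0

# target_mapping[k, i, j] = 1 means the i-th prediction in example k targets the j-th token.
# One prediction per example, pointing at the last position.
target_mapping = torch.zeros(batch_size, 1, seq_length + 1, dtype=torch.float)
target_mapping[:, 0, -1] = 1.0

# inp_q marks the tokens that carry a loss (two-stream attention during pretraining).
inp_q = target_mapping[:, 0, :].clone()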
        config = XLNetConfig(
            vocab_size_or_config_json_file=self.vocab_size,
...
@@ -132,18 +118,23 @@ class XLNetModelTest(unittest.TestCase):
            same_length=self.same_length,
            reuse_len=self.reuse_len,
            bi_data=self.bi_data,
            initializer_range=self.initializer_range)
            initializer_range=self.initializer_range,
            num_labels=self.type_sequence_label_size)

        return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
        return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)

    def set_seed(self):
        random.seed(self.seed)
        torch.manual_seed(self.seed)

    def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
    def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
        model = XLNetModel(config)
        model.eval()

        _, _ = model(input_ids_1, input_mask=input_mask)
        _, _ = model(input_ids_1, attention_mask=input_mask)
        _, _ = model(input_ids_1, token_type_ids=segment_ids)
        outputs, mems_1 = model(input_ids_1)
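The first two calls above exercise XLNet's two masking keywords with the same tensor, presumably just to cover both code paths. Per the input documentation kept earlier in this diff, `input_mask` uses 1 for padding and 0 for real tokens; `attention_mask` is assumed here to follow the usual BERT-style convention (1 for real tokens), in which case one mask is simply the complement of the other. A sketch with placeholder shapes:

import torch

batch_size, seq_length = 2, 7

input_mask = torch.zeros(batch_size, seq_length)  # 0 = real token, 1 = padding
input_mask[:, -2:] = 1.0                          # pretend the last two positions are padding

attention_mask = 1.0 - input_mask                 # assumed convention: 1 = real token, 0 = padding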
...
@@ -159,7 +150,8 @@ class XLNetModelTest(unittest.TestCase):
            list(list(mem.size()) for mem in result["mems_1"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

    def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
    def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
        model = XLNetLMHeadModel(config)
        model.eval()
...
@@ -198,7 +190,82 @@ class XLNetModelTest(unittest.TestCase):
            list(list(mem.size()) for mem in result["mems_2"]),
            [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
    def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
    def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
        model = XLNetForQuestionAnswering(config)
        model.eval()

        outputs = model(input_ids_1)
        start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs

        outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,
                        cls_index=sequence_labels, is_impossible=is_impossible_labels, p_mask=input_mask)

        outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,
                        cls_index=sequence_labels, is_impossible=is_impossible_labels)
        total_loss, start_logits, end_logits, cls_logits, mems = outputs

        outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
        total_loss, start_logits, end_logits, mems = outputs

        result = {
            "loss": total_loss,
            "start_logits": start_logits,
            "end_logits": end_logits,
            "cls_logits": cls_logits,
            "mems": mems,
        }

        self.parent.assertListEqual(
            list(result["loss"].size()),
            [])
        self.parent.assertListEqual(
            list(result["start_logits"].size()),
            [self.batch_size, self.seq_length])
        self.parent.assertListEqual(
            list(result["end_logits"].size()),
            [self.batch_size, self.seq_length])
        self.parent.assertListEqual(
            list(result["cls_logits"].size()),
            [self.batch_size])
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result["mems"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
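The test above pins down the return signatures of the new XLNetForQuestionAnswering head: with span labels it yields a scalar loss plus logits and mems, and without labels it yields beam-search style top-k start/end candidates. A sketch of the two modes, assuming `config`, `input_ids`, `start_positions`, `end_positions`, `cls_index` and `is_impossible` are stand-ins for the tensors the tester prepares:

from pytorch_pretrained_bert.modeling_xlnet import XLNetForQuestionAnswering

model = XLNetForQuestionAnswering(config)
model.eval()

# Supervised call: span and answerability labels give a single training loss.
total_loss, start_logits, end_logits, cls_logits, mems = model(
    input_ids,
    start_positions=start_positions,
    end_positions=end_positions,
    cls_index=cls_index,
    is_impossible=is_impossible)

# Unsupervised call: no labels, the head returns its top-k start/end hypotheses instead.
(start_top_log_probs, start_top_index,
 end_top_log_probs, end_top_index,
 cls_logits, mems) = model(input_ids)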
    def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
        model = XLNetForSequenceClassification(config)
        model.eval()

        logits, mems_1 = model(input_ids_1)
        loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)

        result = {
            "loss": loss,
            "mems_1": mems_1,
            "logits": logits,
        }

        self.parent.assertListEqual(
            list(result["loss"].size()),
            [])
        self.parent.assertListEqual(
            list(result["logits"].size()),
            [self.batch_size, self.type_sequence_label_size])
        self.parent.assertListEqual(
            list(list(mem.size()) for mem in result["mems_1"]),
            [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
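Note that the classifier's output width comes from the configuration: the tester now passes num_labels=self.type_sequence_label_size to XLNetConfig (see the hunk above), and the assertions expect logits of shape [batch_size, type_sequence_label_size]. A minimal sketch of the call pattern, assuming `config`, `input_ids` and `sequence_labels` are built as in the tester:

from pytorch_pretrained_bert.modeling_xlnet import XLNetForSequenceClassification

model = XLNetForSequenceClassification(config)
model.eval()

logits, mems = model(input_ids)                                # logits: [batch_size, num_labels]
loss, logits, mems = model(input_ids, labels=sequence_labels)  # prepends a scalar loss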
    def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
        inputs_dict = {'input_ids': input_ids_1}
        create_and_check_commons(self, config, inputs_dict, test_pruning=False)
...
@@ -224,27 +291,19 @@ class XLNetModelTest(unittest.TestCase):
        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_lm_head(*config_and_inputs)

        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_commons(*config_and_inputs)
        tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_qa(*config_and_inputs)

        tester.set_seed()
        config_and_inputs = tester.prepare_config_and_inputs()
        tester.create_and_check_xlnet_commons(*config_and_inputs)

    @classmethod
    def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a tensor with padding on the right (0.0 for )."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()

if __name__ == "__main__":
...