Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
483cbc36
Commit
483cbc36
authored
Jun 21, 2019
by
thomwolf
Browse files
test deviation with tf model: max ~1e-3 should be ok
parent
24d80689
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
358 additions
and
46 deletions
+358
-46
hubconfs/xlnet_hubconf.py
hubconfs/xlnet_hubconf.py
+169
-0
pytorch_pretrained_bert/__init__.py
pytorch_pretrained_bert/__init__.py
+1
-1
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+167
-26
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+6
-6
tests/tokenization_xlnet_test.py
tests/tokenization_xlnet_test.py
+15
-13
No files found.
hubconfs/xlnet_hubconf.py
0 → 100644
View file @
483cbc36
from
pytorch_pretrained_bert.tokenization_xlnet
import
XLNetTokenizer
from
pytorch_pretrained_bert.modeling_xlnet
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
)
# A lot of models share the same param doc. Use a decorator
# to save typing
xlnet_docstring
=
"""
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific XLNet class
"""
def
_append_from_pretrained_docstring
(
docstr
):
def
docstring_decorator
(
fn
):
fn
.
__doc__
=
fn
.
__doc__
+
docstr
return
fn
return
docstring_decorator
def
xlnetTokenizer
(
*
args
,
**
kwargs
):
"""
Instantiate a XLNet sentencepiece tokenizer for XLNet from a pre-trained vocab file.
Peculiarities:
- require Google sentencepiece (https://github.com/google/sentencepiece)
Args:
pretrained_model_name_or_path: Path to pretrained model archive
or one of pre-trained vocab configs below.
* xlnet-large-cased
Keyword args:
special_tokens: Special tokens in vocabulary that are not pretrained
Default: None
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying model's
sequence length.
Default: None
Example:
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
>>> text = "Who was Jim Henson ?"
>>> indexed_tokens = tokenizer.encode(tokenized_text)
"""
tokenizer
=
XLNetTokenizer
.
from_pretrained
(
*
args
,
**
kwargs
)
return
tokenizer
@
_append_from_pretrained_docstring
(
xlnet_docstring
)
def
xlnetModel
(
*
args
,
**
kwargs
):
"""
xlnetModel is the basic XLNet Transformer model from
"XLNet: Generalized Autoregressive Pretraining for Language Understanding"
by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
# Prepare tokenized input
>>> text_1 = "Who was Jim Henson ?"
>>> text_2 = "Jim Henson was a puppeteer"
>>> indexed_tokens_1 = tokenizer.encode(text_1)
>>> indexed_tokens_2 = tokenizer.encode(text_2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
# Load xlnetModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetModel', 'xlnet-large-cased')
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
hidden_states_1, mems = model(tokens_tensor_1)
hidden_states_2, mems = model(tokens_tensor_2, past=mems)
"""
model
=
XLNetModel
.
from_pretrained
(
*
args
,
**
kwargs
)
return
model
@
_append_from_pretrained_docstring
(
xlnet_docstring
)
def
xlnetLMHeadModel
(
*
args
,
**
kwargs
):
"""
xlnetModel is the basic XLNet Transformer model from
"XLNet: Generalized Autoregressive Pretraining for Language Understanding"
by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
with a tied (pre-trained) language modeling head on top.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
# Prepare tokenized input
>>> text_1 = "Who was Jim Henson ?"
>>> text_2 = "Jim Henson was a puppeteer"
>>> indexed_tokens_1 = tokenizer.encode(text_1)
>>> indexed_tokens_2 = tokenizer.encode(text_2)
>>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
>>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
# Load xlnetLMHeadModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetLMHeadModel', 'xlnet-large-cased')
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
predictions_1, mems = model(tokens_tensor_1)
predictions_2, mems = model(tokens_tensor_2, mems=mems)
# Get the predicted last token
>>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
>>> predicted_token = tokenizer.decode([predicted_index])
>>> assert predicted_token == ' who'
"""
model
=
XLNetLMHeadModel
.
from_pretrained
(
*
args
,
**
kwargs
)
return
model
@
_append_from_pretrained_docstring
(
xlnet_docstring
)
def
xlnetForSequenceClassification
(
*
args
,
**
kwargs
):
"""
xlnetModel is the basic XLNet Transformer model from
"XLNet: Generalized Autoregressive Pretraining for Language Understanding"
by Zhilin Yang, Zihang Dai1, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetTokenizer', 'xlnet-large-cased')
# Prepare tokenized input
>>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer"
>>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man"
>>> tokenized_text1 = tokenizer.tokenize(text1)
>>> tokenized_text2 = tokenizer.tokenize(text2)
>>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1)
>>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2)
>>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]])
>>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]])
# Load xlnetForSequenceClassification
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'xlnetForSequenceClassification', 'xlnet-large-cased')
>>> model.eval()
# Predict sequence classes logits
>>> with torch.no_grad():
lm_logits, mems = model(tokens_tensor)
"""
model
=
XLNetForSequenceClassification
.
from_pretrained
(
*
args
,
**
kwargs
)
return
model
pytorch_pretrained_bert/__init__.py
View file @
483cbc36
...
@@ -3,7 +3,7 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
...
@@ -3,7 +3,7 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from
.tokenization_openai
import
OpenAIGPTTokenizer
from
.tokenization_openai
import
OpenAIGPTTokenizer
from
.tokenization_transfo_xl
import
(
TransfoXLTokenizer
,
TransfoXLCorpus
)
from
.tokenization_transfo_xl
import
(
TransfoXLTokenizer
,
TransfoXLCorpus
)
from
.tokenization_gpt2
import
GPT2Tokenizer
from
.tokenization_gpt2
import
GPT2Tokenizer
from
.tokenization_xlnet
import
XLNetTokenizer
from
.tokenization_xlnet
import
XLNetTokenizer
,
SPIECE_UNDERLINE
from
.modeling
import
(
BertConfig
,
BertModel
,
BertForPreTraining
,
from
.modeling
import
(
BertConfig
,
BertModel
,
BertForPreTraining
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
483cbc36
...
@@ -165,12 +165,12 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
...
@@ -165,12 +165,12 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
def
gelu
(
x
):
def
gelu
(
x
):
"""Implementation of the gelu activation function.
""" Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
XLNet is using OpenAI GPT's gelu (not exactly the same as BERT)
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
Also see https://arxiv.org/abs/1606.08415
"""
"""
return
x
*
0.5
*
(
1.0
+
torch
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
cdf
=
0.5
*
(
1.0
+
torch
.
tanh
(
math
.
sqrt
(
2
/
math
.
pi
)
*
(
x
+
0.044715
*
torch
.
pow
(
x
,
3
))))
return
x
*
cdf
def
swish
(
x
):
def
swish
(
x
):
...
@@ -657,7 +657,7 @@ class XLNetPreTrainedModel(nn.Module):
...
@@ -657,7 +657,7 @@ class XLNetPreTrainedModel(nn.Module):
- a str with the name of a pre-trained model to load selected in the list of:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `
xlnet_
config.json` a configuration file for the model
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `xlnet_config.json` a configuration file for the model
...
@@ -767,6 +767,8 @@ class XLNetPreTrainedModel(nn.Module):
...
@@ -767,6 +767,8 @@ class XLNetPreTrainedModel(nn.Module):
if
len
(
error_msgs
)
>
0
:
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
if
isinstance
(
model
,
XLNetLMHeadModel
):
model
.
tie_weights
()
# make sure word embedding weights are still tied
return
model
return
model
...
@@ -894,23 +896,23 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -894,23 +896,23 @@ class XLNetModel(XLNetPreTrainedModel):
output_all_encoded_layers
=
True
,
head_mask
=
None
):
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
"""
Args:
Args:
inp_k: int32 Tensor in shape [
len, bsz
], the input token IDs.
inp_k: int32 Tensor in shape [
bsz, len
], the input token IDs.
seg_id: int32 Tensor in shape [
len, bsz
], the input segment IDs.
seg_id: int32 Tensor in shape [
bsz, len
], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [
len, bsz
], the input mask.
input_mask: [optional] float32 Tensor in shape [
bsz, len
], the input mask.
0 for real tokens and 1 for padding.
0 for real tokens and 1 for padding.
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
If None, no memory is used.
perm_mask: [optional] float32 Tensor in shape [len, len
, bsz
].
perm_mask: [optional] float32 Tensor in shape [
bsz,
len, len].
If perm_mask[i, j
, k
] = 0, i attend to j in batch k;
If perm_mask[
k,
i, j] = 0, i attend to j in batch k;
if perm_mask[i, j
, k
] = 1, i does not attend to j in batch k.
if perm_mask[
k,
i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
If None, each position attends to all the others.
target_mapping: [optional] float32 Tensor in shape [num_predict, len
, bsz
].
target_mapping: [optional] float32 Tensor in shape [
bsz,
num_predict, len].
If target_mapping[i, j
, k
] = 1, the i-th predict in batch k is
If target_mapping[
k,
i, j] = 1, the i-th predict in batch k is
on the j-th token.
on the j-th token.
Only used during pretraining for partial prediction.
Only used during pretraining for partial prediction.
Set to None during finetuning.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [
len, bsz
].
inp_q: [optional] float32 Tensor in shape [
bsz, len
].
1 for tokens with losses and 0 for tokens without losses.
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
Set to None during finetuning.
...
@@ -926,6 +928,16 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -926,6 +928,16 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
to pool the input to get a vector representation.
"""
"""
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
inp_k
=
inp_k
.
transpose
(
0
,
1
).
contiguous
()
seg_id
=
seg_id
.
transpose
(
0
,
1
).
contiguous
()
if
seg_id
is
not
None
else
None
input_mask
=
input_mask
.
transpose
(
0
,
1
).
contiguous
()
if
input_mask
is
not
None
else
None
perm_mask
=
perm_mask
.
permute
(
1
,
2
,
0
).
contiguous
()
if
perm_mask
is
not
None
else
None
target_mapping
=
target_mapping
.
permute
(
1
,
2
,
0
).
contiguous
()
if
target_mapping
is
not
None
else
None
inp_q
=
inp_q
.
transpose
(
0
,
1
).
contiguous
()
if
inp_q
is
not
None
else
None
qlen
,
bsz
=
inp_k
.
shape
[
0
],
inp_k
.
shape
[
1
]
qlen
,
bsz
=
inp_k
.
shape
[
0
],
inp_k
.
shape
[
1
]
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
klen
=
mlen
+
qlen
...
@@ -1020,6 +1032,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -1020,6 +1032,7 @@ class XLNetModel(XLNetPreTrainedModel):
if
mems
is
None
:
if
mems
is
None
:
mems
=
[
None
]
*
len
(
self
.
layer
)
mems
=
[
None
]
*
len
(
self
.
layer
)
hidden_states
=
[]
for
i
,
layer_module
in
enumerate
(
self
.
layer
):
for
i
,
layer_module
in
enumerate
(
self
.
layer
):
# cache new mems
# cache new mems
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
...
@@ -1029,10 +1042,14 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -1029,10 +1042,14 @@ class XLNetModel(XLNetPreTrainedModel):
r
=
pos_emb
,
seg_mat
=
seg_mat
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
head_mask
=
head_mask
)
hidden_states
.
append
(
output_h
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
return
output
,
new_mems
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output
=
output
.
permute
(
1
,
0
,
2
).
contiguous
()
hidden_states
=
[
hs
.
permute
(
1
,
0
,
2
).
contiguous
()
for
hs
in
hidden_states
]
return
output
,
hidden_states
,
new_mems
class
XLNetLMHeadModel
(
XLNetPreTrainedModel
):
class
XLNetLMHeadModel
(
XLNetPreTrainedModel
):
...
@@ -1110,23 +1127,23 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1110,23 +1127,23 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
target
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
target
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
"""
Args:
Args:
inp_k: int32 Tensor in shape [
len, bsz
], the input token IDs.
inp_k: int32 Tensor in shape [
bsz, len
], the input token IDs.
seg_id: int32 Tensor in shape [
len, bsz
], the input segment IDs.
seg_id: int32 Tensor in shape [
bsz, len
], the input segment IDs.
input_mask: float32 Tensor in shape [
len, bsz
], the input mask.
input_mask: float32 Tensor in shape [
bsz, len
], the input mask.
0 for real tokens and 1 for padding.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
If None, no memory is used.
perm_mask: float32 Tensor in shape [len, len
, bsz
].
perm_mask: float32 Tensor in shape [
bsz,
len, len].
If perm_mask[i, j
, k
] = 0, i attend to j in batch k;
If perm_mask[
k,
i, j] = 0, i attend to j in batch k;
if perm_mask[i, j
, k
] = 1, i does not attend to j in batch k.
if perm_mask[
k,
i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [num_predict, len
, bsz
].
target_mapping: float32 Tensor in shape [
bsz,
num_predict, len].
If target_mapping[i, j
, k
] = 1, the i-th predict in batch k is
If target_mapping[
k,
i, j] = 1, the i-th predict in batch k is
on the j-th token.
on the j-th token.
Only used during pretraining for partial prediction.
Only used during pretraining for partial prediction.
Set to None during finetuning.
Set to None during finetuning.
inp_q: float32 Tensor in shape [
len, bsz
].
inp_q: float32 Tensor in shape [
bsz, len
].
1 for tokens with losses and 0 for tokens without losses.
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
Set to None during finetuning.
...
@@ -1134,7 +1151,131 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1134,7 +1151,131 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
to pool the input to get a vector representation.
"""
"""
output
,
new_mems
=
self
.
transformer
(
inp_k
,
seg_id
,
input_mask
,
output
,
hidden_states
,
new_mems
=
self
.
transformer
(
inp_k
,
seg_id
,
input_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
output_all_encoded_layers
,
head_mask
)
logits
=
self
.
lm_loss
(
output
)
if
target
is
not
None
:
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
size
(
-
1
)),
target
.
view
(
-
1
))
return
loss
,
new_mems
# if self.output_attentions:
# all_attentions, encoded_layers = encoded_layers
# sequence_output = encoded_layers[-1]
# pooled_output = self.pooler(sequence_output)
# if not output_all_encoded_layers:
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
return
logits
,
new_mems
# return all_attentions, encoded_layers, pooled_output
class
XLNetForSequenceClassification
(
XLNetPreTrainedModel
):
"""XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").
Params:
`config`: a XLNetConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
`summary_type`: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation. Default: last
Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: float32 Tensor in shape [bsz, len, len].
If perm_mask[k, i, j] = 0, i attend to j in batch k;
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [bsz, num_predict, len].
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs: Tuple of (logits or loss, mems)
`logits or loss`:
if target is None:
Token logits with shape [batch_size, sequence_length]
else:
CrossEntropy loss with the targets
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
summary_type
=
"last"
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
XLNetForSequenceClassification
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
attn_type
=
config
.
attn_type
self
.
same_length
=
config
.
same_length
self
.
summary_type
=
summary_type
self
.
transformer
=
XLNetModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_loss
=
nn
.
Linear
(
config
.
d_model
,
config
.
n_token
,
bias
=
True
)
self
.
apply
(
self
.
init_xlnet_weights
)
self
.
tie_weights
()
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
target
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: float32 Tensor in shape [bsz, len, len].
If perm_mask[k, i, j] = 0, i attend to j in batch k;
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [bsz, num_predict, len].
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
"""
output
,
hidden_states
,
new_mems
=
self
.
transformer
(
inp_k
,
seg_id
,
input_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
output_all_encoded_layers
,
head_mask
)
output_all_encoded_layers
,
head_mask
)
...
...
tests/modeling_xlnet_test.py
View file @
483cbc36
...
@@ -74,9 +74,9 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -74,9 +74,9 @@ class XLNetModelTest(unittest.TestCase):
self
.
type_vocab_size
=
type_vocab_size
self
.
type_vocab_size
=
type_vocab_size
def
prepare_config_and_inputs
(
self
):
def
prepare_config_and_inputs
(
self
):
input_ids_1
=
XLNetModelTest
.
ids_tensor
([
self
.
seq_length
,
self
.
batch_size
],
self
.
vocab_size
)
input_ids_1
=
XLNetModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
XLNetModelTest
.
ids_tensor
([
self
.
seq_length
,
self
.
batch_size
],
self
.
vocab_size
)
input_ids_2
=
XLNetModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
segment_ids
=
XLNetModelTest
.
ids_tensor
([
self
.
seq_length
,
self
.
batch_size
],
self
.
type_vocab_size
)
segment_ids
=
XLNetModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
# inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
# inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
# seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
# seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
...
@@ -101,7 +101,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -101,7 +101,7 @@ class XLNetModelTest(unittest.TestCase):
lm_labels
=
None
lm_labels
=
None
if
self
.
use_labels
:
if
self
.
use_labels
:
lm_labels
=
XLNetModelTest
.
ids_tensor
([
self
.
seq_length
,
self
.
batch_size
],
self
.
vocab_size
)
lm_labels
=
XLNetModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
XLNetConfig
(
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
@@ -155,7 +155,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -155,7 +155,7 @@ class XLNetModelTest(unittest.TestCase):
[])
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_1"
].
size
()),
list
(
result
[
"lm_logits_1"
].
size
()),
[
self
.
seq_length
,
self
.
batch_size
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1a"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1a"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
...
@@ -171,7 +171,7 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -171,7 +171,7 @@ class XLNetModelTest(unittest.TestCase):
[])
[])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_2"
].
size
()),
list
(
result
[
"lm_logits_2"
].
size
()),
[
self
.
seq_length
,
self
.
batch_size
,
self
.
vocab_size
])
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2a"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_2a"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
d_model
]]
*
self
.
n_layer
)
...
...
tests/tokenization_xlnet_test.py
View file @
483cbc36
...
@@ -20,7 +20,9 @@ from io import open
...
@@ -20,7 +20,9 @@ from io import open
import
shutil
import
shutil
import
pytest
import
pytest
from
pytorch_pretrained_bert.tokenization_xlnet
import
(
XLNetTokenizer
,
PRETRAINED_VOCAB_ARCHIVE_MAP
)
from
pytorch_pretrained_bert.tokenization_xlnet
import
(
XLNetTokenizer
,
PRETRAINED_VOCAB_ARCHIVE_MAP
,
SPIECE_UNDERLINE
)
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))),
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))),
...
@@ -45,9 +47,9 @@ class XLNetTokenizationTest(unittest.TestCase):
...
@@ -45,9 +47,9 @@ class XLNetTokenizationTest(unittest.TestCase):
os
.
remove
(
special_tokens_file
)
os
.
remove
(
special_tokens_file
)
tokens
=
tokenizer
.
tokenize
(
"I was born in 92000, and this is falsé."
)
tokens
=
tokenizer
.
tokenize
(
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
'▁I'
,
'▁was'
,
'▁b'
,
'or'
,
'n'
,
'▁in'
,
'▁
'
,
self
.
assertListEqual
(
tokens
,
[
SPIECE_UNDERLINE
+
'I'
,
SPIECE_UNDERLINE
+
'was'
,
SPIECE_UNDERLINE
+
'b'
,
'or'
,
'n'
,
SPIECE_UNDERLINE
+
'in'
,
SPIECE_UNDERLINE
+
'
'
,
'9'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
'▁and'
,
'▁
this'
,
'9'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
SPIECE_UNDERLINE
+
'and'
,
SPIECE_UNDERLINE
+
'
this'
,
'▁is'
,
'▁
f'
,
'al'
,
's'
,
'é'
,
'.'
])
SPIECE_UNDERLINE
+
'is'
,
SPIECE_UNDERLINE
+
'
f'
,
'al'
,
's'
,
'é'
,
'.'
])
ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
self
.
assertListEqual
(
self
.
assertListEqual
(
ids
,
[
8
,
21
,
84
,
55
,
24
,
19
,
7
,
0
,
ids
,
[
8
,
21
,
84
,
55
,
24
,
19
,
7
,
0
,
...
@@ -55,9 +57,9 @@ class XLNetTokenizationTest(unittest.TestCase):
...
@@ -55,9 +57,9 @@ class XLNetTokenizationTest(unittest.TestCase):
46
,
72
,
80
,
6
,
0
,
4
])
46
,
72
,
80
,
6
,
0
,
4
])
back_tokens
=
tokenizer
.
convert_ids_to_tokens
(
ids
)
back_tokens
=
tokenizer
.
convert_ids_to_tokens
(
ids
)
self
.
assertListEqual
(
back_tokens
,
[
'▁I'
,
'▁was'
,
'▁b'
,
'or'
,
'n'
,
'▁
in'
,
self
.
assertListEqual
(
back_tokens
,
[
SPIECE_UNDERLINE
+
'I'
,
SPIECE_UNDERLINE
+
'was'
,
SPIECE_UNDERLINE
+
'b'
,
'or'
,
'n'
,
SPIECE_UNDERLINE
+
'
in'
,
'
▁
'
,
'<unk>'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
SPIECE_UNDERLINE
+
''
,
'<unk>'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
'▁and'
,
'▁this'
,
'▁is'
,
'▁
f'
,
'al'
,
's'
,
SPIECE_UNDERLINE
+
'and'
,
SPIECE_UNDERLINE
+
'this'
,
SPIECE_UNDERLINE
+
'is'
,
SPIECE_UNDERLINE
+
'
f'
,
'al'
,
's'
,
'<unk>'
,
'.'
])
'<unk>'
,
'.'
])
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
...
@@ -71,17 +73,17 @@ class XLNetTokenizationTest(unittest.TestCase):
...
@@ -71,17 +73,17 @@ class XLNetTokenizationTest(unittest.TestCase):
def
test_tokenizer_lower
(
self
):
def
test_tokenizer_lower
(
self
):
tokenizer
=
XLNetTokenizer
(
SAMPLE_VOCAB
,
do_lower_case
=
True
)
tokenizer
=
XLNetTokenizer
(
SAMPLE_VOCAB
,
do_lower_case
=
True
)
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
'▁'
,
'i'
,
'▁was'
,
'▁b'
,
'or'
,
'n'
,
'▁in'
,
'▁
'
,
self
.
assertListEqual
(
tokens
,
[
SPIECE_UNDERLINE
+
''
,
'i'
,
SPIECE_UNDERLINE
+
'was'
,
SPIECE_UNDERLINE
+
'b'
,
'or'
,
'n'
,
SPIECE_UNDERLINE
+
'in'
,
SPIECE_UNDERLINE
+
'
'
,
'9'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
'▁and'
,
'▁
this'
,
'9'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
SPIECE_UNDERLINE
+
'and'
,
SPIECE_UNDERLINE
+
'
this'
,
'▁is'
,
'▁
f'
,
'al'
,
'se'
,
'.'
])
SPIECE_UNDERLINE
+
'is'
,
SPIECE_UNDERLINE
+
'
f'
,
'al'
,
'se'
,
'.'
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"H
\u00E9
llo"
),
[
"▁he"
,
"ll"
,
"o"
])
self
.
assertListEqual
(
tokenizer
.
tokenize
(
u
"H
\u00E9
llo"
),
[
"▁he"
,
"ll"
,
"o"
])
def
test_tokenizer_no_lower
(
self
):
def
test_tokenizer_no_lower
(
self
):
tokenizer
=
XLNetTokenizer
(
SAMPLE_VOCAB
,
do_lower_case
=
False
)
tokenizer
=
XLNetTokenizer
(
SAMPLE_VOCAB
,
do_lower_case
=
False
)
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
'▁I'
,
'▁was'
,
'▁b'
,
'or'
,
'n'
,
'▁in'
,
'▁
'
,
self
.
assertListEqual
(
tokens
,
[
SPIECE_UNDERLINE
+
'I'
,
SPIECE_UNDERLINE
+
'was'
,
SPIECE_UNDERLINE
+
'b'
,
'or'
,
'n'
,
SPIECE_UNDERLINE
+
'in'
,
SPIECE_UNDERLINE
+
'
'
,
'9'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
'▁and'
,
'▁
this'
,
'9'
,
'2'
,
'0'
,
'0'
,
'0'
,
','
,
SPIECE_UNDERLINE
+
'and'
,
SPIECE_UNDERLINE
+
'
this'
,
'▁is'
,
'▁
f'
,
'al'
,
'se'
,
'.'
])
SPIECE_UNDERLINE
+
'is'
,
SPIECE_UNDERLINE
+
'
f'
,
'al'
,
'se'
,
'.'
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment