chenpangpang / transformers · Commits

Commit ebd2cb8d
authored Jun 21, 2019 by thomwolf

update from_pretrained to load XLNetModel as well

parent 483cbc36
Showing 4 changed files with 99 additions and 36 deletions (+99 -36):

    examples/generation_xlnet.py                   +21  -0
    pytorch_pretrained_bert/modeling_xlnet.py      +33  -18
    pytorch_pretrained_bert/tokenization_xlnet.py  +15  -0
    tests/modeling_xlnet_test.py                   +30  -18
examples/generation_xlnet.py  (new file, 0 → 100644)
import torch
from torch.nn import functional as F
from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer

import logging
logging.basicConfig(level=logging.INFO)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

tokens = tokenizer.encode('I am very ')

for i in range(len(tokens), 20):
    mask = torch.tensor([[[0.0] * i + [1.0]]])
    logits, _ = model(torch.tensor([tokens + [0]]),
                      perm_mask=mask.expand(-1, i + 1, -1),
                      target_mapping=mask,
                      inp_q=mask.squeeze(1))
    output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
    tokens.append(output.item())

print(tokenizer.decode(tokens))
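
For readers of the example above, a minimal shape sketch (not part of the commit), assuming a current prefix length of i = 3; the variable names mirror the example's:

import torch

i = 3                                       # assumed current prefix length
mask = torch.tensor([[[0.0] * i + [1.0]]])  # [1, 1, i + 1]: only the appended placeholder slot is marked
perm_mask = mask.expand(-1, i + 1, -1)      # [1, i + 1, i + 1]: no position may attend to that last slot
target_mapping = mask                       # [1, 1, i + 1]: exactly one prediction, the last slot
inp_q = mask.squeeze(1)                     # [1, i + 1]: the query stream / loss covers only the last slot

print(perm_mask.shape, target_mapping.shape, inp_q.shape)
# torch.Size([1, 4, 4]) torch.Size([1, 1, 4]) torch.Size([1, 4])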
pytorch_pretrained_bert/modeling_xlnet.py
@@ -727,16 +727,24 @@ class XLNetPreTrainedModel(nn.Module):
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = XLNetConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Update config with kwargs if needed
        for key, value in kwargs:
            if hasattr(config, key):
                setattr(config, key, value)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_xlnet(model, resolved_archive_file)
+            return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
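
As the commit title notes, from_pretrained can now instantiate the bare backbone as well as the LM-head model. A minimal usage sketch (the checkpoint name is the one used in the example above):

from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel

base = XLNetModel.from_pretrained('xlnet-large-cased')       # backbone only
lm = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')   # backbone + LM head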
@@ -755,8 +763,8 @@ class XLNetPreTrainedModel(nn.Module):
                if child is not None:
                    load(child, prefix + name + '.')
        start_prefix = ''
-        if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()):
-            start_prefix = 'xlnet.'
+        if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
+            start_prefix = 'transformer.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
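
A standalone sketch of the prefix handling changed above (the state_dict keys and the hasattr flag are hypothetical, for illustration only): when the checkpoint was saved from a model that keeps the backbone under a `transformer` attribute but the target model is the bare backbone, the 'transformer.' prefix is prepended while matching parameter names.

# Hypothetical keys; real XLNet checkpoints may differ.
state_dict_keys = ['transformer.word_embedding.weight', 'lm_loss.weight']
model_has_transformer_attr = False  # e.g. loading into XLNetModel itself

start_prefix = ''
if not model_has_transformer_attr and any(k.startswith('transformer') for k in state_dict_keys):
    start_prefix = 'transformer.'
print(start_prefix)  # -> 'transformer.'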
@@ -989,10 +997,10 @@ class XLNetModel(XLNetPreTrainedModel):
        output_h = self.dropout(word_emb_k)
        if inp_q is not None:
            if target_mapping is not None:
-                word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
+                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            else:
                inp_q_ext = inp_q[:, :, None]
-                word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
+                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None
@@ -1062,19 +1070,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            This can be used to compute head importance metrics. Default: False

    Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
-            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-            a `sentence B` token (see XLNet paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
-        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
-            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+        seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+        input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
+            0 for real tokens and 1 for padding.
+        mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            from previous batches. The length of the list equals n_layer.
+            If None, no memory is used.
+        perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
+            If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
+            If None, each position attends to all the others.
+        target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
+            If target_mapping[k, i, j] = 1, the i-th predict in batch k is
+            on the j-th token.
+            Only used during pretraining for partial prediction.
+            Set to None during finetuning.
+        inp_q: [optional] float32 Tensor in shape [bsz, len].
+            1 for tokens with losses and 0 for tokens without losses.
+            Only used during pretraining for two-stream attention.
+            Set to None during finetuning.
    Outputs: Tuple of (encoded_layers, pooled_output)
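
The rewritten docstring documents `mems` as a per-layer memory carried over from previous batches. A minimal sketch of the call pattern (it mirrors the test changes below; `model`, the `chunk_*_ids` token tensors and the `seg_ids_*` segment tensors are placeholders shaped [bsz, len] as documented):

# First chunk of a stream: no memory yet.
all_logits_a, mems = model(chunk_a_ids, seg_id=seg_ids_a)
# Next chunk: reuse the cached hidden states returned by the previous call.
all_logits_b, mems = model(chunk_b_ids, seg_id=seg_ids_b, mems=mems)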
pytorch_pretrained_bert/tokenization_xlnet.py
@@ -37,6 +37,11 @@ VOCAB_NAME = 'spiece.model'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'

SPIECE_UNDERLINE = '▁'

+SEG_ID_A = 0
+SEG_ID_B = 1
+SEG_ID_CLS = 2
+SEG_ID_SEP = 3
+SEG_ID_PAD = 4

class XLNetTokenizer(object):
    """
@@ -52,6 +57,16 @@ class XLNetTokenizer(object):
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
+            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
+                               "you may want to check this behavior.")
+                kwargs['do_lower_case'] = False
+            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
+                               "but you may want to check this behavior.")
+                kwargs['do_lower_case'] = True
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
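
A brief usage sketch of the new casing guard (not part of the commit): with the default kwargs, loading the cased checkpoint now logs the warning above and forces do_lower_case=False.

from pytorch_pretrained_bert import XLNetTokenizer

# No explicit do_lower_case: the '-cased' branch fires, logs the warning, and sets do_lower_case=False.
tok = XLNetTokenizer.from_pretrained('xlnet-large-cased')

# Passing do_lower_case=False explicitly gives the same result without the warning.
tok = XLNetTokenizer.from_pretrained('xlnet-large-cased', do_lower_case=False)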
tests/modeling_xlnet_test.py
@@ -78,23 +78,30 @@ class XLNetModelTest(unittest.TestCase):
            input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

-            # inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
-            # seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
-            # input_mask: float32 Tensor in shape [len, bsz], the input mask.
+            input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
+            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
+            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
+            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
+            target_mapping[:, 0, -1] = 1.0  # predict last token
+            inp_q = target_mapping[:, 0, :].clone()  # predict last token

+            # inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+            # seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+            # input_mask: float32 Tensor in shape [bsz, len], the input mask.
            # 0 for real tokens and 1 for padding.
-            # mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            # mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
            # from previous batches. The length of the list equals n_layer.
            # If None, no memory is used.
-            # perm_mask: float32 Tensor in shape [len, len, bsz].
-            # If perm_mask[i, j, k] = 0, i attend to j in batch k;
-            # if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
+            # perm_mask: float32 Tensor in shape [bsz, len, len].
+            # If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            # if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
            # If None, each position attends to all the others.
-            # target_mapping: float32 Tensor in shape [num_predict, len, bsz].
-            # If target_mapping[i, j, k] = 1, the i-th predict in batch k is
+            # target_mapping: float32 Tensor in shape [bsz, num_predict, len].
+            # If target_mapping[k, i, j] = 1, the i-th predict in batch k is
            # on the j-th token.
            # Only used during pretraining for partial prediction.
            # Set to None during finetuning.
-            # inp_q: float32 Tensor in shape [len, bsz].
+            # inp_q: float32 Tensor in shape [bsz, len].
            # 1 for tokens with losses and 0 for tokens without losses.
            # Only used during pretraining for two-stream attention.
            # Set to None during finetuning.
@@ -121,30 +128,35 @@ class XLNetModelTest(unittest.TestCase):
            config.update(run_config)
-            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)
+            return (config, input_ids_1, input_ids_2, input_ids_q,
+                    perm_mask, target_mapping, inp_q, segment_ids, lm_labels)

        def set_seed(self):
            random.seed(self.seed)
            torch.manual_seed(self.seed)

-        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
+        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q,
+                                      perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
+            all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
            loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            logits, _ = model(input_ids_q, perm_mask=perm_mask,
+                              target_mapping=target_mapping, inp_q=inp_q)

            outputs = {
                "loss_1": loss_1,
                "mems_1a": mems_1a,
-                "lm_logits_1": lm_logits_1,
+                "all_logits_1": all_logits_1,
                "mems_1b": mems_1b,
                "loss_2": loss_2,
                "mems_2a": mems_2a,
-                "lm_logits_2": lm_logits_2,
+                "all_logits_2": all_logits_2,
                "mems_2b": mems_2b,
            }
            return outputs
@@ -154,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
                list(result["loss_1"].size()),
                [])
            self.parent.assertListEqual(
-                list(result["lm_logits_1"].size()),
+                list(result["all_logits_1"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
@@ -170,7 +182,7 @@ class XLNetModelTest(unittest.TestCase):
                list(result["loss_2"].size()),
                [])
            self.parent.assertListEqual(
-                list(result["lm_logits_2"].size()),
+                list(result["all_logits_2"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),