chenpangpang / transformers · Commits

Commit ebd2cb8d, authored Jun 21, 2019 by thomwolf

    update from_pretrained to load XLNetModel as well

Parent: 483cbc36
Showing 4 changed files with 99 additions and 36 deletions (+99 -36)
    examples/generation_xlnet.py                   +21  -0
    pytorch_pretrained_bert/modeling_xlnet.py      +33 -18
    pytorch_pretrained_bert/tokenization_xlnet.py  +15  -0
    tests/modeling_xlnet_test.py                   +30 -18
examples/generation_xlnet.py (new file, 0 → 100644)
import torch
from torch.nn import functional as F
from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel, XLNetTokenizer

import logging
logging.basicConfig(level=logging.INFO)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetModel.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

tokens = tokenizer.encode('I am very ')
for i in range(len(tokens), 20):
    mask = torch.tensor([[[0.0] * i + [1.0]]])
    logits, _ = model(torch.tensor([tokens + [0]]),
                      perm_mask=mask.expand(-1, i + 1, -1),
                      target_mapping=mask,
                      inp_q=mask.squeeze(1))
    output = torch.multinomial(F.softmax(logits[0, 0, :]), 1)
    tokens.append(output.item())
print(tokenizer.decode(tokens))
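Editorial note (not part of the commit): the loop above predicts one token per step via XLNet's two-stream attention, appending a dummy slot and masking it out. A minimal annotated sketch of the same loop with the tensor shapes spelled out and an explicit softmax dimension, assuming the same pytorch_pretrained_bert API shown above:

import logging

import torch
from torch.nn import functional as F
from pytorch_pretrained_bert import XLNetLMHeadModel, XLNetTokenizer

logging.basicConfig(level=logging.INFO)

tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')
model.eval()  # not in the original script; disables dropout while sampling

tokens = tokenizer.encode('I am very ')
with torch.no_grad():
    for i in range(len(tokens), 20):
        # mask has shape [1, 1, i+1]: zeros over the i known tokens,
        # a one on the extra slot whose token we want to predict.
        mask = torch.tensor([[[0.0] * i + [1.0]]])
        inp = torch.tensor([tokens + [0]])  # dummy id 0 in the predicted slot
        logits, _ = model(inp,
                          perm_mask=mask.expand(-1, i + 1, -1),  # [1, i+1, i+1]: no position attends to the slot
                          target_mapping=mask,                   # [1, 1, i+1]: one prediction, placed on the slot
                          inp_q=mask.squeeze(1))                 # [1, i+1]: the slot carries the query stream
        probs = F.softmax(logits[0, 0, :], dim=-1)  # dim made explicit
        tokens.append(torch.multinomial(probs, 1).item())
print(tokenizer.decode(tokens))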
pytorch_pretrained_bert/modeling_xlnet.py
@@ -727,16 +727,24 @@ class XLNetPreTrainedModel(nn.Module):
                 archive_file, resolved_archive_file))
         logger.info("loading configuration file {} from cache at {}".format(
             config_file, resolved_config_file))
         # Load config
         config = XLNetConfig.from_json_file(resolved_config_file)
         logger.info("Model config {}".format(config))
+        # Update config with kwargs if needed
+        for key, value in kwargs:
+            if hasattr(config, key):
+                setattr(config, key, value)
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_xlnet(model, resolved_archive_file)
+            return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
         # Load from a PyTorch state_dict
         missing_keys = []
         unexpected_keys = []
@@ -755,8 +763,8 @@ class XLNetPreTrainedModel(nn.Module):
                 if child is not None:
                     load(child, prefix + name + '.')
         start_prefix = ''
-        if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()):
-            start_prefix = 'xlnet.'
+        if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
+            start_prefix = 'transformer.'
         load(model, prefix=start_prefix)
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
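Editorial note (not part of the diff): the prefix switch from 'xlnet.' to 'transformer.' is what the commit message refers to. XLNetLMHeadModel keeps the base network under a `transformer` attribute, so the same checkpoint can now feed both classes; a sketch of the two load paths this enables:

from pytorch_pretrained_bert import XLNetModel, XLNetLMHeadModel

# Head model: it has a `transformer` attribute, so checkpoint keys of the form
# 'transformer.<param>' (illustrative key shape, assumed here) load directly.
lm_model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased')

# Bare model: no `transformer` attribute, but the checkpoint keys start with
# 'transformer', so from_pretrained sets start_prefix='transformer.' and the
# recursive load() matches the prefixed keys against the bare parameters.
base_model = XLNetModel.from_pretrained('xlnet-large-cased')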
@@ -989,10 +997,10 @@ class XLNetModel(XLNetPreTrainedModel):
         output_h = self.dropout(word_emb_k)
         if inp_q is not None:
             if target_mapping is not None:
-                word_emb_q = mask_emb.expand(target_mapping.shape[0], bsz, -1)
+                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
             else:
                 inp_q_ext = inp_q[:, :, None]
-                word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
+                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
             output_g = self.dropout(word_emb_q)
         else:
             output_g = None
@@ -1062,19 +1070,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             This can be used to compute head importance metrics. Default: False
     Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
-            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
-            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
-        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
-            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
-            a `sentence B` token (see XLNet paper for more details).
-        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
-            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
-            input sequence length in the current batch. It's the mask that we typically use for attention when
-            a batch has varying length sentences.
-        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
-        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
-            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+        inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+        seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+        input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
+            0 for real tokens and 1 for padding.
+        mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            from previous batches. The length of the list equals n_layer.
+            If None, no memory is used.
+        perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
+            If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
+            If None, each position attends to all the others.
+        target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
+            If target_mapping[k, i, j] = 1, the i-th predict in batch k is
+            on the j-th token.
+            Only used during pretraining for partial prediction.
+            Set to None during finetuning.
+        inp_q: [optional] float32 Tensor in shape [bsz, len].
+            1 for tokens with losses and 0 for tokens without losses.
+            Only used during pretraining for two-stream attention.
+            Set to None during finetuning.
     Outputs: Tuple of (encoded_layers, pooled_output)
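Editorial note (not part of the diff): the rewritten docstring above switches every input to batch-first ([bsz, len]). A short, self-contained illustration of the documented semantics, mirroring what the updated test file below constructs:

import torch

bsz, seq_len = 2, 5  # toy sizes for the illustration

# perm_mask[k, i, j] = 1 means position i may not attend to position j in
# example k. Here no position may look at the last token:
perm_mask = torch.zeros(bsz, seq_len, seq_len)
perm_mask[:, :, -1] = 1.0

# target_mapping[k, i, j] = 1 puts the i-th prediction of example k on token j.
# One prediction per example, placed on the last position:
target_mapping = torch.zeros(bsz, 1, seq_len)
target_mapping[:, 0, -1] = 1.0

# inp_q marks which positions carry a loss (the query stream); here it is the
# same single position per example:
inp_q = target_mapping[:, 0, :].clone()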
pytorch_pretrained_bert/tokenization_xlnet.py
@@ -37,6 +37,11 @@ VOCAB_NAME = 'spiece.model'
 SPECIAL_TOKENS_NAME = 'special_tokens.txt'
 SPIECE_UNDERLINE = '▁'
 
+SEG_ID_A   = 0
+SEG_ID_B   = 1
+SEG_ID_CLS = 2
+SEG_ID_SEP = 3
+SEG_ID_PAD = 4
 
 class XLNetTokenizer(object):
     """
@@ -52,6 +57,16 @@ class XLNetTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
             special_tokens_file = None
+            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
+                               "you may want to check this behavior.")
+                kwargs['do_lower_case'] = False
+            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
+                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
+                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
+                               "but you may want to check this behavior.")
+                kwargs['do_lower_case'] = True
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
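Editorial note (not part of the diff): the new SEG_ID_* constants follow the segment-id scheme of the upstream XLNet preprocessing, where each `<sep>` takes the id of the sentence it closes and `<cls>` gets its own id; that layout is an assumption taken from the original XLNet code, not shown in this commit. A hedged sketch:

from pytorch_pretrained_bert.tokenization_xlnet import (
    SEG_ID_A, SEG_ID_B, SEG_ID_CLS)

# Segment ids for a sentence pair laid out XLNet-style:
#   tokens_a <sep> tokens_b <sep> <cls>
len_a, len_b = 4, 3  # token counts, chosen arbitrarily for the example
seg_id = ([SEG_ID_A] * (len_a + 1)    # sentence A and its <sep>
          + [SEG_ID_B] * (len_b + 1)  # sentence B and its <sep>
          + [SEG_ID_CLS])             # trailing <cls>
assert len(seg_id) == len_a + len_b + 3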
tests/modeling_xlnet_test.py
@@ -78,23 +78,30 @@ class XLNetModelTest(unittest.TestCase):
             input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
             segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
-            # inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
-            # seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
-            # input_mask: float32 Tensor in shape [len, bsz], the input mask.
+            input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
+            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
+            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
+            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
+            target_mapping[:, 0, -1] = 1.0  # predict last token
+            inp_q = target_mapping[:, 0, :].clone()  # predict last token
+
+            # inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
+            # seg_id: int32 Tensor in shape [bsz, len], the input segment IDs.
+            # input_mask: float32 Tensor in shape [bsz, len], the input mask.
             # 0 for real tokens and 1 for padding.
-            # mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+            # mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
             # from previous batches. The length of the list equals n_layer.
             # If None, no memory is used.
-            # perm_mask: float32 Tensor in shape [len, len, bsz].
-            # If perm_mask[i, j, k] = 0, i attend to j in batch k;
-            # if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
+            # perm_mask: float32 Tensor in shape [bsz, len, len].
+            # If perm_mask[k, i, j] = 0, i attend to j in batch k;
+            # if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
             # If None, each position attends to all the others.
-            # target_mapping: float32 Tensor in shape [num_predict, len, bsz].
-            # If target_mapping[i, j, k] = 1, the i-th predict in batch k is
+            # target_mapping: float32 Tensor in shape [bsz, num_predict, len].
+            # If target_mapping[k, i, j] = 1, the i-th predict in batch k is
             # on the j-th token.
             # Only used during pretraining for partial prediction.
             # Set to None during finetuning.
-            # inp_q: float32 Tensor in shape [len, bsz].
+            # inp_q: float32 Tensor in shape [bsz, len].
             # 1 for tokens with losses and 0 for tokens without losses.
             # Only used during pretraining for two-stream attention.
             # Set to None during finetuning.
@@ -121,30 +128,35 @@ class XLNetModelTest(unittest.TestCase):
             config.update(run_config)
 
-            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)
+            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels)
 
         def set_seed(self):
             random.seed(self.seed)
             torch.manual_seed(self.seed)
 
-        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
+        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
             model = XLNetLMHeadModel(config)
             model.eval()
 
             loss_1, mems_1a = model(input_ids_1, seg_id=segment_ids, target=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
+            all_logits_1, mems_1b = model(input_ids_1, seg_id=segment_ids)
             loss_2, mems_2a = model(input_ids_2, seg_id=segment_ids, target=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            all_logits_2, mems_2b = model(input_ids_2, seg_id=segment_ids, mems=mems_1b)
+            logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
 
             outputs = {
                 "loss_1": loss_1,
                 "mems_1a": mems_1a,
-                "lm_logits_1": lm_logits_1,
+                "all_logits_1": all_logits_1,
                 "mems_1b": mems_1b,
                 "loss_2": loss_2,
                 "mems_2a": mems_2a,
-                "lm_logits_2": lm_logits_2,
+                "all_logits_2": all_logits_2,
                 "mems_2b": mems_2b,
             }
             return outputs
@@ -154,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
                 list(result["loss_1"].size()),
                 [])
             self.parent.assertListEqual(
-                list(result["lm_logits_1"].size()),
+                list(result["all_logits_1"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1a"]),
@@ -170,7 +182,7 @@ class XLNetModelTest(unittest.TestCase):
                 list(result["loss_2"].size()),
                 [])
             self.parent.assertListEqual(
-                list(result["lm_logits_2"].size()),
+                list(result["all_logits_2"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2a"]),