Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c306869e
Commit
c306869e
authored
Feb 07, 2019
by
thomwolf
Browse files
add two transformer xl models
parent
d482e3d7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
175 additions
and
65 deletions
+175
-65
pytorch_pretrained_bert/__init__.py
pytorch_pretrained_bert/__init__.py
+1
-1
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
...etrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
+5
-5
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+169
-59
No files found.
pytorch_pretrained_bert/__init__.py
View file @
c306869e
...
...
@@ -11,7 +11,7 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
from
.modeling_openai
import
(
OpenAIGPTConfig
,
OpenAIGPTModel
,
OpenAIGPTLMHeadModel
,
OpenAIGPTDoubleHeadsModel
,
load_tf_weights_in_openai_gpt
)
from
.modeling_transfo_xl
import
(
TransfoXLConfig
,
TransfoXLModel
,
from
.modeling_transfo_xl
import
(
TransfoXLConfig
,
TransfoXLModel
,
TransfoXLLMHeadModel
,
load_tf_weights_in_transfo_xl
)
from
.optimization
import
BertAdam
...
...
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
View file @
c306869e
...
...
@@ -27,7 +27,7 @@ import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
from
pytorch_pretrained_bert.modeling_transfo_xl
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
TransfoXLConfig
,
TransfoXLModel
,
TransfoXL
LMHead
Model
,
load_tf_weights_in_transfo_xl
)
from
pytorch_pretrained_bert.tokenization_transfo_xl
import
(
CORPUS_NAME
,
VOCAB_NAME
)
...
...
@@ -37,7 +37,7 @@ if sys.version_info[0] == 2:
else
:
import
pickle
# We do this to be able to load
the
python 2 datasets pickles
# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils
.
Vocab
=
data_utils
.
TransfoXLTokenizer
data_utils
.
Corpus
=
data_utils
.
TransfoXLCorpus
...
...
@@ -49,6 +49,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
pytorch_dump_folder_path
,
transfo_xl_dataset_file
):
if
transfo_xl_dataset_file
:
# Convert a pre-processed corpus (see original TensorFlow repo)
with
open
(
transfo_xl_dataset_file
,
"rb"
)
as
fp
:
corpus
=
pickle
.
load
(
fp
,
encoding
=
"latin1"
)
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
...
...
@@ -64,18 +65,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
torch
.
save
(
corpus_dict_no_vocab
,
pytorch_dataset_dump_path
)
if
tf_checkpoint_path
:
# Convert a pre-trained TensorFlow model
config_path
=
os
.
path
.
abspath
(
transfo_xl_config_file
)
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
print
(
"Converting Transformer XL checkpoint from {} with config at {}"
.
format
(
tf_path
,
config_path
))
# Initialise PyTorch model
# Construct model
if
transfo_xl_config_file
==
""
:
config
=
TransfoXLConfig
()
else
:
config
=
TransfoXLConfig
(
transfo_xl_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
TransfoXLModel
(
config
)
model
=
TransfoXL
LMHead
Model
(
config
)
model
=
load_tf_weights_in_transfo_xl
(
model
,
config
,
tf_path
)
# Save pytorch-model
...
...
@@ -90,7 +91,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
## Required parameters
parser
.
add_argument
(
"--pytorch_dump_folder_path"
,
default
=
None
,
type
=
str
,
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
c306869e
...
...
@@ -57,7 +57,7 @@ def build_tf_to_pytorch_map(model, config):
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
"""
tf_to_pt_map
=
{}
# Embeddings
cutoffs
# Embeddings
for
i
,
(
embed_l
,
proj_l
)
in
enumerate
(
zip
(
model
.
word_emb
.
emb_layers
,
model
.
word_emb
.
emb_projs
)):
layer_str
=
"transformer/adaptive_embed/cutoff_%d/"
%
i
tf_to_pt_map
.
update
({
...
...
@@ -934,11 +934,11 @@ class TransfoXLPreTrainedModel(nn.Module):
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
)
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
weights_path
=
os
.
path
.
join
(
serialization_dir
,
TF_WEIGHTS_NAME
)
return
load_tf_weights_in_transfo_xl
(
model
,
weights_path
)
return
load_tf_weights_in_transfo_xl
(
model
,
config
,
pretrained_model_name_or_path
)
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
...
...
@@ -965,18 +965,49 @@ class TransfoXLPreTrainedModel(nn.Module):
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
# Make sure we are still sharing the input and output embeddings
if
model
.
hasattr
(
'tie_weights'
):
model
.
tie_weights
()
return
model
class
TransfoXLModel
(
TransfoXLPreTrainedModel
):
"""Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
- you don't need to specify positioning embeddings indices
- the tokens in the vocabulary have to be sorted to decreasing frequency.
Params:
config: a TransfoXLConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
Outputs:
A tuple of (last_hidden_state, new_mems)
`last_hidden_state`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [sequence_length, batch_size, self.config.d_model]
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])
config = TransfoXLConfig()
model = TransfoXLModel(config)
last_hidden_state, new_mems = model(input_ids)
# Another time on input_ids_next using the memory:
last_hidden_state, new_mems = model(input_ids_next, new_mems)
```
"""
def
__init__
(
self
,
config
):
# n_token, n_layer, n_head, d_model, d_head, d_inner,
# dropout, dropatt, tie_weight=True, d_embed=None,
# div_val=1, tie_projs=[False], pre_lnorm=False,
# tgt_len=None, ext_len=None, mem_len=None,
# cutoffs=[], adapt_inp=False, untie_r=False,
# same_length=False, attn_type=0, clamp_len=-1,
# sample_softmax=-1, **kwargs):
super
(
TransfoXLModel
,
self
).
__init__
(
config
)
self
.
n_token
=
config
.
n_token
...
...
@@ -1034,31 +1065,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
)
)
self
.
sample_softmax
=
config
.
sample_softmax
# use sampled softmax
if
config
.
sample_softmax
>
0
:
self
.
out_layer
=
nn
.
Linear
(
config
.
d_model
,
config
.
n_token
)
if
config
.
tie_weight
:
self
.
out_layer
.
weight
=
self
.
word_emb
.
weight
self
.
tie_weight
=
config
.
tie_weight
self
.
sampler
=
LogUniformSampler
(
config
.
n_token
,
config
.
sample_softmax
)
# use adaptive softmax (including standard softmax)
else
:
self
.
crit
=
ProjectedAdaptiveLogSoftmax
(
config
.
n_token
,
config
.
d_embed
,
config
.
d_model
,
config
.
cutoffs
,
div_val
=
config
.
div_val
)
if
config
.
tie_weight
:
for
i
in
range
(
len
(
self
.
crit
.
out_layers
)):
self
.
crit
.
out_layers
[
i
].
weight
=
self
.
word_emb
.
emb_layers
[
i
].
weight
if
config
.
tie_projs
:
for
i
,
tie_proj
in
enumerate
(
config
.
tie_projs
):
if
tie_proj
and
config
.
div_val
==
1
and
config
.
d_model
!=
config
.
d_embed
:
self
.
crit
.
out_projs
[
i
]
=
self
.
word_emb
.
emb_projs
[
0
]
elif
tie_proj
and
config
.
div_val
!=
1
:
self
.
crit
.
out_projs
[
i
]
=
self
.
word_emb
.
emb_projs
[
i
]
self
.
same_length
=
config
.
same_length
self
.
clamp_len
=
config
.
clamp_len
...
...
@@ -1074,6 +1080,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
elif
self
.
attn_type
==
3
:
# absolute deeper SA
self
.
r_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
,
self
.
d_head
))
self
.
apply
(
self
.
init_weights
)
def
backward_compatible
(
self
):
self
.
sample_softmax
=
-
1
...
...
@@ -1210,32 +1217,135 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return
core_out
,
new_mems
def
forward
(
self
,
data
,
target
=
None
,
*
mems
):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if
not
mems
:
mems
=
self
.
init_mems
(
data
)
hidden
,
new_mems
=
self
.
_forward
(
data
,
mems
=
mems
)
if
target
is
None
:
if
new_mems
is
None
:
return
[
hidden
]
def
forward
(
self
,
input_ids
,
mems
=
None
):
""" Params:
input_ids :: [len, bsz]
Returns:
tuple (last_hidden, new_mems) where:
new_mems: list (num layers) of mem states at the entry of each layer
shape :: [self.config.mem_len, bsz, self.config.d_model]
last_hidden: output of the last layer:
shape :: [len, bsz, self.config.d_model]
"""
if
mems
is
None
:
mems
=
self
.
init_mems
(
input_ids
)
last_hidden
,
new_mems
=
self
.
_forward
(
input_ids
,
mems
=
mems
)
return
(
last_hidden
,
new_mems
)
class
TransfoXLLMHeadModel
(
TransfoXLPreTrainedModel
):
"""Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
This model add an (adaptive) softmax head on top of the TransfoXLModel
Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
- you don't need to specify positioning embeddings indices
- the tokens in the vocabulary have to be sorted to decreasing frequency.
Call self.tie_weights() if you update/load the weights of the transformer to keep the weights tied.
Params:
config: a TransfoXLConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
`target`: a torch.LongTensor of shape [sequence_length, batch_size]
with the target token indices selected in the range [0, self.config.n_token[
Outputs:
A tuple of (last_hidden_state, new_mems)
`softmax_output`: output of the (adaptive) softmax:
if target is None:
Negative log likelihood of shape :: [len, bsz]
else:
return
[
hidden
]
+
new_mems
log probabilities of tokens, shape :: [len, bsz, n_tokens]
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
tgt_len
=
target
.
size
(
0
)
pred_hid
=
hidden
[
-
tgt_len
:]
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])
config = TransfoXLConfig()
model = TransfoXLModel(config)
last_hidden_state, new_mems = model(input_ids)
# Another time on input_ids_next using the memory:
last_hidden_state, new_mems = model(input_ids_next, new_mems)
```
"""
def
__init__
(
self
,
config
):
super
(
TransfoXLLMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
TransfoXLModel
(
config
)
self
.
sample_softmax
=
config
.
sample_softmax
# use sampled softmax
if
config
.
sample_softmax
>
0
:
self
.
out_layer
=
nn
.
Linear
(
config
.
d_model
,
config
.
n_token
)
self
.
sampler
=
LogUniformSampler
(
config
.
n_token
,
config
.
sample_softmax
)
# use adaptive softmax (including standard softmax)
else
:
self
.
crit
=
ProjectedAdaptiveLogSoftmax
(
config
.
n_token
,
config
.
d_embed
,
config
.
d_model
,
config
.
cutoffs
,
div_val
=
config
.
div_val
)
self
.
apply
(
self
.
init_weights
)
self
.
tie_weights
()
def
tie_weights
(
self
):
""" Run this to be sure output and input (adaptive) softmax weights are tied """
# sampled softmax
if
self
.
sample_softmax
>
0
:
if
self
.
config
.
tie_weight
:
self
.
out_layer
.
weight
=
self
.
transformer
.
word_emb
.
weight
# adaptive softmax (including standard softmax)
else
:
if
self
.
config
.
tie_weight
:
for
i
in
range
(
len
(
self
.
crit
.
out_layers
)):
self
.
crit
.
out_layers
[
i
].
weight
=
self
.
transformer
.
word_emb
.
emb_layers
[
i
].
weight
if
self
.
config
.
tie_projs
:
for
i
,
tie_proj
in
enumerate
(
self
.
config
.
tie_projs
):
if
tie_proj
and
self
.
config
.
div_val
==
1
and
self
.
config
.
d_model
!=
self
.
config
.
d_embed
:
self
.
crit
.
out_projs
[
i
]
=
self
.
transformer
.
word_emb
.
emb_projs
[
0
]
elif
tie_proj
and
self
.
config
.
div_val
!=
1
:
self
.
crit
.
out_projs
[
i
]
=
self
.
transformer
.
word_emb
.
emb_projs
[
i
]
def
reset_length
(
self
,
tgt_len
,
ext_len
,
mem_len
):
self
.
transformer
.
reset_length
(
tgt_len
,
ext_len
,
mem_len
)
def
init_mems
(
self
,
data
):
return
self
.
transformer
.
init_mems
(
data
)
def
forward
(
self
,
input_ids
,
target
=
None
,
mems
=
None
):
""" Params:
input_ids :: [len, bsz]
target :: [len, bsz]
Returns:
tuple(softmax_output, new_mems) where:
new_mems: list (num layers) of hidden states at the entry of each layer
shape :: [mem_len, bsz, self.config.d_model]
softmax_output: output of the (adaptive) softmax:
if target is None:
Negative log likelihood of shape :: [len, bsz]
else:
log probabilities of tokens, shape :: [len, bsz, n_tokens]
"""
bsz
=
input_ids
.
size
(
1
)
tgt_len
=
input_ids
.
size
(
0
)
last_hidden
,
new_mems
=
self
.
transformer
(
input_ids
,
mems
)
pred_hid
=
last_hidden
[
-
tgt_len
:]
if
self
.
sample_softmax
>
0
and
self
.
training
:
assert
self
.
tie_weight
logit
=
sample_logits
(
self
.
word_emb
,
self
.
out_layer
.
bias
,
target
,
pred_hid
,
self
.
sampler
)
assert
self
.
config
.
tie_weight
logit
=
sample_logits
(
self
.
transformer
.
word_emb
,
self
.
out_layer
.
bias
,
target
,
pred_hid
,
self
.
sampler
)
loss
=
-
F
.
log_softmax
(
logit
,
-
1
)[:,
:,
0
]
else
:
loss
=
self
.
crit
(
pred_hid
.
view
(
-
1
,
pred_hid
.
size
(
-
1
)),
target
.
view
(
-
1
))
loss
=
loss
.
view
(
tgt_len
,
-
1
)
softmax_output
=
self
.
crit
(
pred_hid
.
view
(
-
1
,
pred_hid
.
size
(
-
1
)),
target
)
if
target
is
None
:
softmax_output
=
softmax_output
.
view
(
tgt_len
,
bsz
,
-
1
)
else
:
softmax_output
=
softmax_output
.
view
(
tgt_len
,
bsz
)
if
new_mems
is
None
:
return
[
loss
]
else
:
return
(
loss
,
new_mems
)
return
(
softmax_output
,
new_mems
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment