chenpangpang / transformers · Commits

Commit 8831c688, authored Jan 16, 2019 by thomwolf
fixing various parts of model conversion, loading and weights sharing
parent bcd4aa8f

Showing 4 changed files with 243 additions and 283 deletions:

  examples/eval_transfo_xl.py                                            +1    -1
  pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py    +3    -2
  pytorch_pretrained_bert/modeling_transfo_xl.py                         +233  -280
  pytorch_pretrained_bert/tokenization_transfo_xl.py                     +6    -0
examples/eval_transfo_xl.py (view file @ 8831c688)

@@ -42,7 +42,7 @@ parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model
 # parser.add_argument('--data', type=str, default='../data/wikitext-103',
 #                     help='location of the data corpus')
 parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
-                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'], #
+                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
                     help='pretrained model name')
 parser.add_argument('--split', type=str, default='test',
                     choices=['all', 'valid', 'test'],
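For context, the flag touched by this hunk restricts the evaluation script to the one converted WikiText-103 checkpoint; the commented-out names are datasets from the original Transformer-XL code without a converted checkpoint. A minimal, self-contained sketch of the two arguments shown here (nothing beyond the flags visible in the hunk):

```python
import argparse

# Minimal sketch of the argument handling in this hunk; only the
# 'transfo-xl-wt103' checkpoint is selectable for now.
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    choices=['transfo-xl-wt103'],  # 'lm1b', 'enwik8', 'text8' not converted yet
                    help='pretrained model name')
parser.add_argument('--split', type=str, default='test',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')

if __name__ == '__main__':
    args = parser.parse_args()
    print(args.model_name, args.split)
```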
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py (view file @ 8831c688)

@@ -116,7 +116,8 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
     pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
     print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
-    torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
+    corpus_vocab_dict = corpus.vocab.__dict__
+    torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
     corpus_dict_no_vocab = corpus.__dict__
     corpus_dict_no_vocab.pop('vocab', None)
@@ -139,7 +140,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     model = TransfoXLModel(config)

     # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
+    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

     # Load weights from TF model
     init_vars = tf.train.list_variables(tf_path)
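The first hunk only introduces an intermediate variable, but the surrounding code's intent is to persist the vocabulary and the corpus cache as plain dictionaries rather than pickled objects. A rough sketch of that save/load round trip, assuming generic `corpus`/`vocab` objects whose state lives in `__dict__` (the helper names are illustrative, not this module's API):

```python
import torch

def save_corpus_cache(corpus, vocab_path, corpus_path):
    # Dump the vocabulary attributes separately from the rest of the corpus,
    # mirroring the split introduced in this hunk.
    corpus_vocab_dict = corpus.vocab.__dict__
    torch.save(corpus_vocab_dict, vocab_path)
    corpus_dict_no_vocab = dict(corpus.__dict__)
    corpus_dict_no_vocab.pop('vocab', None)
    torch.save(corpus_dict_no_vocab, corpus_path)

def load_corpus_cache(corpus, vocab, vocab_path, corpus_path):
    # Restore the attributes onto freshly constructed (empty) corpus/vocab objects.
    vocab.__dict__.update(torch.load(vocab_path))
    corpus.__dict__.update(torch.load(corpus_path))
    corpus.vocab = vocab
    return corpus
```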
pytorch_pretrained_bert/modeling_transfo_xl.py (view file @ 8831c688)

@@ -30,6 +30,7 @@ import collections
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
@@ -40,7 +41,10 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
+}
+PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-transfo_xl_config.json",
 }
 CONFIG_NAME = 'transfo_xl_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
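Splitting the single `.tar.gz` archive into separate weight and config URLs lets `from_pretrained` resolve a shortcut name to two files, or fall back to a local directory, without extracting an archive. A minimal sketch of that resolution step, using only the maps and filenames defined above:

```python
import os

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-transfo_xl_config.json",
}
CONFIG_NAME = 'transfo_xl_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

def resolve_files(pretrained_model_name_or_path):
    """Return (weights, config) locations for a shortcut name or a local folder."""
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
        archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
        config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    return archive_file, config_file

print(resolve_files('transfo-xl-wt103'))
```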
@@ -674,99 +678,266 @@ class AdaptiveEmbedding(nn.Module):
         return embed

-class MemTransformerLM(nn.Module):
-    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
-                 dropout, dropatt, tie_weight=True, d_embed=None,
-                 div_val=1, tie_projs=[False], pre_lnorm=False,
-                 tgt_len=None, ext_len=None, mem_len=None,
-                 cutoffs=[], adapt_inp=False, untie_r=False,
-                 same_length=False, attn_type=0, clamp_len=-1,
-                 sample_softmax=-1, **kwargs):
-        super(MemTransformerLM, self).__init__()
-        self.n_token = n_token
-
-        d_embed = d_model if d_embed is None else d_embed
-        self.d_embed = d_embed
-        self.d_model = d_model
-        self.n_head = n_head
-        self.d_head = d_head
-
-        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
-                                          div_val=div_val)
-
-        self.drop = nn.Dropout(dropout)
-
-        self.n_layer = n_layer
-
-        self.tgt_len = tgt_len
-        self.mem_len = mem_len
-        self.ext_len = ext_len
-        self.max_klen = tgt_len + ext_len + mem_len
-
-        self.attn_type = attn_type
+class TransfoXLPreTrainedModel(nn.Module):
+    """ An abstract class to handle weights initialization and
+        a simple interface for dowloading and loading pretrained models.
+    """
+    def __init__(self, config, *inputs, **kwargs):
+        super(TransfoXLPreTrainedModel, self).__init__()
+        if not isinstance(config, TransfoXLConfig):
+            raise ValueError(
+                "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
+                "To create a model from a pretrained model use "
+                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+                    self.__class__.__name__, self.__class__.__name__))
+        self.config = config
+
+    def init_weight(self, weight):
+        if self.config.init == 'uniform':
+            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
+        elif self.config.init == 'normal':
+            nn.init.normal_(weight, 0.0, self.config.init_std)
+
+    def init_bias(self, bias):
+        nn.init.constant_(bias, 0.0)
+
+    def init_weights(self, m):
+        """ Initialize the weights.
+        """
+        classname = m.__class__.__name__
+        if classname.find('Linear') != -1:
+            if hasattr(m, 'weight') and m.weight is not None:
+                self.init_weight(m.weight)
+            if hasattr(m, 'bias') and m.bias is not None:
+                self.init_bias(m.bias)
+        elif classname.find('AdaptiveEmbedding') != -1:
+            if hasattr(m, 'emb_projs'):
+                for i in range(len(m.emb_projs)):
+                    if m.emb_projs[i] is not None:
+                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
+        elif classname.find('Embedding') != -1:
+            if hasattr(m, 'weight'):
+                self.init_weight(m.weight)
+        elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
+            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
+                self.init_weight(m.cluster_weight)
+            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
+                self.init_bias(m.cluster_bias)
+            if hasattr(m, 'out_projs'):
+                for i in range(len(m.out_projs)):
+                    if m.out_projs[i] is not None:
+                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
+        elif classname.find('LayerNorm') != -1:
+            if hasattr(m, 'weight'):
+                nn.init.normal_(m.weight, 1.0, self.config.init_std)
+            if hasattr(m, 'bias') and m.bias is not None:
+                self.init_bias(m.bias)
+        elif classname.find('TransformerLM') != -1:
+            if hasattr(m, 'r_emb'):
+                self.init_weight(m.r_emb)
+            if hasattr(m, 'r_w_bias'):
+                self.init_weight(m.r_w_bias)
+            if hasattr(m, 'r_r_bias'):
+                self.init_weight(m.r_r_bias)
+            if hasattr(m, 'r_bias'):
+                self.init_bias(m.r_bias)
+
+    def set_num_special_tokens(self, num_special_tokens):
+        pass
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
+                        *inputs, **kwargs):
+        """
+        Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
+        Download and cache the pre-trained model file if needed.
+
+        Params:
+            pretrained_model_name_or_path: either:
+                - a str with the name of a pre-trained model to load selected in the list of:
+                    . `transfo-xl`
+                - a path or url to a pretrained model archive containing:
+                    . `transfo_xl_config.json` a configuration file for the model
+                    . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
+            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
+            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
+            *inputs, **kwargs: additional input for the specific Bert class
+                (ex: num_labels for BertForSequenceClassification)
+        """
+        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
+            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
+        else:
+            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except FileNotFoundError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                "at this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
+                    pretrained_model_name_or_path,
+                    archive_file, config_file))
+            return None
+        if resolved_archive_file == archive_file and resolved_config_file == config_file:
+            logger.info("loading weights file {}".format(archive_file))
+            logger.info("loading configuration file {}".format(config_file))
+        else:
+            logger.info("loading weights file {} from cache at {}".format(
+                archive_file, resolved_archive_file))
+            logger.info("loading configuration file {} from cache at {}".format(
+                config_file, resolved_config_file))
+        # Load config
+        config = TransfoXLConfig.from_json_file(resolved_config_file)
+        logger.info("Model config {}".format(config))
+        # Instantiate model.
+        model = cls(config, *inputs, **kwargs)
+        if state_dict is None:
+            state_dict = torch.load(resolved_archive_file)
+
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            if 'gamma' in key:
+                new_key = key.replace('gamma', 'weight')
+            if 'beta' in key:
+                new_key = key.replace('beta', 'bias')
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+        # copy state_dict so _load_from_state_dict can modify it
+        metadata = getattr(state_dict, '_metadata', None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+
+        def load(module, prefix=''):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            module._load_from_state_dict(
+                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, prefix + name + '.')
+        # load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
+        if len(missing_keys) > 0:
+            logger.info("Weights of {} not initialized from pretrained model: {}".format(
+                model.__class__.__name__, missing_keys))
+        if len(unexpected_keys) > 0:
+            logger.info("Weights from pretrained model not used in {}: {}".format(
+                model.__class__.__name__, unexpected_keys))
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+        return model
+
+
+class TransfoXLModel(TransfoXLPreTrainedModel):
+    def __init__(self, config):
+                 # n_token, n_layer, n_head, d_model, d_head, d_inner,
+                 # dropout, dropatt, tie_weight=True, d_embed=None,
+                 # div_val=1, tie_projs=[False], pre_lnorm=False,
+                 # tgt_len=None, ext_len=None, mem_len=None,
+                 # cutoffs=[], adapt_inp=False, untie_r=False,
+                 # same_length=False, attn_type=0, clamp_len=-1,
+                 # sample_softmax=-1, **kwargs):
+        super(TransfoXLModel, self).__init__(config)
+        self.n_token = config.n_token
+        self.d_embed = config.d_embed
+        self.d_model = config.d_model
+        self.n_head = config.n_head
+        self.d_head = config.d_head
+
+        self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
+                                          div_val=config.div_val)
+
+        self.drop = nn.Dropout(config.dropout)
+
+        self.n_layer = config.n_layer
+
+        self.tgt_len = config.tgt_len
+        self.mem_len = config.mem_len
+        self.ext_len = config.ext_len
+        self.max_klen = config.tgt_len + config.ext_len + config.mem_len
+
+        self.attn_type = config.attn_type

-        if not untie_r:
+        if not config.untie_r:
             self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
             self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

         self.layers = nn.ModuleList()
-        if attn_type == 0: # the default attention
-            for i in range(n_layer):
+        if config.attn_type == 0: # the default attention
+            for i in range(config.n_layer):
                 self.layers.append(
                     RelPartialLearnableDecoderLayer(
-                        n_head, d_model, d_head, d_inner, dropout,
-                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        r_w_bias=None if untie_r else self.r_w_bias,
-                        r_r_bias=None if untie_r else self.r_r_bias)
+                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
+                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
+                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
+                        r_w_bias=None if config.untie_r else self.r_w_bias,
+                        r_r_bias=None if config.untie_r else self.r_r_bias)
                 )
-        elif attn_type == 1: # learnable embeddings
-            for i in range(n_layer):
+        elif config.attn_type == 1: # learnable embeddings
+            for i in range(config.n_layer):
                 self.layers.append(
                     RelLearnableDecoderLayer(
-                        n_head, d_model, d_head, d_inner, dropout,
-                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        r_w_bias=None if untie_r else self.r_w_bias,
-                        r_r_bias=None if untie_r else self.r_r_bias)
+                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
+                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
+                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
+                        r_w_bias=None if config.untie_r else self.r_w_bias,
+                        r_r_bias=None if config.untie_r else self.r_r_bias)
                 )
-        elif attn_type in [2, 3]: # absolute embeddings
-            for i in range(n_layer):
+        elif config.attn_type in [2, 3]: # absolute embeddings
+            for i in range(config.n_layer):
                 self.layers.append(
                     DecoderLayer(
-                        n_head, d_model, d_head, d_inner, dropout,
-                        dropatt=dropatt, pre_lnorm=pre_lnorm,
-                        r_w_bias=None if untie_r else self.r_w_bias,
-                        r_r_bias=None if untie_r else self.r_r_bias)
+                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
+                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
+                        r_w_bias=None if config.untie_r else self.r_w_bias,
+                        r_r_bias=None if config.untie_r else self.r_r_bias)
                 )

-        self.sample_softmax = sample_softmax
+        self.sample_softmax = config.sample_softmax
         # use sampled softmax
-        if sample_softmax > 0:
-            self.out_layer = nn.Linear(d_model, n_token)
-            if tie_weight:
+        if config.sample_softmax > 0:
+            self.out_layer = nn.Linear(config.d_model, config.n_token)
+            if config.tie_weight:
                 self.out_layer.weight = self.word_emb.weight
-            self.tie_weight = tie_weight
-            self.sampler = LogUniformSampler(n_token, sample_softmax)
+            self.tie_weight = config.tie_weight
+            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)

         # use adaptive softmax (including standard softmax)
         else:
-            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
-                                                    cutoffs, div_val=div_val)
+            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
+                                                    config.cutoffs, div_val=config.div_val)

-            if tie_weight:
+            if config.tie_weight:
                 for i in range(len(self.crit.out_layers)):
                     self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

-            if tie_projs:
-                for i, tie_proj in enumerate(tie_projs):
-                    if tie_proj and div_val == 1 and d_model != d_embed:
+            if config.tie_projs:
+                for i, tie_proj in enumerate(config.tie_projs):
+                    if tie_proj and config.div_val == 1 and config.d_model != config.d_embed:
                         self.crit.out_projs[i] = self.word_emb.emb_projs[0]
-                    elif tie_proj and div_val != 1:
+                    elif tie_proj and config.div_val != 1:
                         self.crit.out_projs[i] = self.word_emb.emb_projs[i]

-        self.same_length = same_length
-        self.clamp_len = clamp_len
+        self.same_length = config.same_length
+        self.clamp_len = config.clamp_len

         if self.attn_type == 0: # default attention
             self.pos_emb = PositionalEmbedding(self.d_model)
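One detail of this hunk that is easy to miss is the `gamma`/`beta` renaming applied to the checkpoint before loading: it maps old TF-style LayerNorm parameter names onto PyTorch's `weight`/`bias` convention inside `from_pretrained`. A standalone sketch of just that renaming step, operating on an ordinary state dict (the function name is illustrative):

```python
import torch

def rename_tf_style_keys(state_dict):
    # Map 'gamma'/'beta' parameter names (LayerNorm in TF-derived checkpoints)
    # onto PyTorch's 'weight'/'bias', as done inside from_pretrained above.
    old_keys, new_keys = [], []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict

# Example: fake checkpoint entries get renamed before load_state_dict.
sd = {'layer_norm.gamma': torch.ones(4), 'layer_norm.beta': torch.zeros(4)}
print(list(rename_tf_style_keys(sd).keys()))  # ['layer_norm.weight', 'layer_norm.bias']
```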
@@ -859,8 +1030,7 @@ class MemTransformerLM(nn.Module):
             hids.append(core_out)
             for i, layer in enumerate(self.layers):
                 mems_i = None if mems is None else mems[i]
-                core_out = layer(core_out, pos_emb, self.r_w_bias,
-                        self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
+                core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
         elif self.attn_type == 1: # learnable
             core_out = self.drop(word_emb)
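This forward-pass change goes together with the constructor change in the previous hunk: the relative-attention biases are now handed to each decoder layer when it is built (shared across layers unless `untie_r` is set), so they no longer have to be passed on every call. A toy sketch of that sharing pattern, with made-up module names standing in for the real decoder layers:

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self, n_head, d_head, r_w_bias=None, r_r_bias=None):
        super(ToyLayer, self).__init__()
        # If no shared bias is given, the layer owns its own (untied) copy.
        self.r_w_bias = nn.Parameter(torch.zeros(n_head, d_head)) if r_w_bias is None else r_w_bias
        self.r_r_bias = nn.Parameter(torch.zeros(n_head, d_head)) if r_r_bias is None else r_r_bias

class ToyModel(nn.Module):
    def __init__(self, n_layer=2, n_head=4, d_head=8, untie_r=False):
        super(ToyModel, self).__init__()
        if not untie_r:
            self.r_w_bias = nn.Parameter(torch.zeros(n_head, d_head))
            self.r_r_bias = nn.Parameter(torch.zeros(n_head, d_head))
        self.layers = nn.ModuleList([
            ToyLayer(n_head, d_head,
                     r_w_bias=None if untie_r else self.r_w_bias,
                     r_r_bias=None if untie_r else self.r_r_bias)
            for _ in range(n_layer)
        ])

model = ToyModel(untie_r=False)
# With tied biases, every layer holds the very same Parameter object.
print(model.layers[0].r_w_bias is model.layers[1].r_w_bias)  # True
```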
@@ -949,220 +1119,3 @@ class MemTransformerLM(nn.Module):
         else:
             return [loss] + new_mems
-
-
-class TransfoXLPreTrainedModel(nn.Module):
-    """ An abstract class to handle weights initialization and
-        a simple interface for dowloading and loading pretrained models.
-    """
-    def __init__(self, config, *inputs, **kwargs):
-        super(TransfoXLPreTrainedModel, self).__init__()
-        if not isinstance(config, TransfoXLConfig):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
-                "To create a model from a pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__))
-        self.config = config
-
-    def init_weight(self, weight):
-        if self.config.init == 'uniform':
-            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
-        elif self.config.init == 'normal':
-            nn.init.normal_(weight, 0.0, self.config.init_std)
-
-    def init_bias(self, bias):
-        nn.init.constant_(bias, 0.0)
-
-    def init_weights(self, m):
-        """ Initialize the weights.
-        """
-        classname = m.__class__.__name__
-        if classname.find('Linear') != -1:
-            if hasattr(m, 'weight') and m.weight is not None:
-                self.init_weight(m.weight)
-            if hasattr(m, 'bias') and m.bias is not None:
-                self.init_bias(m.bias)
-        elif classname.find('AdaptiveEmbedding') != -1:
-            if hasattr(m, 'emb_projs'):
-                for i in range(len(m.emb_projs)):
-                    if m.emb_projs[i] is not None:
-                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
-        elif classname.find('Embedding') != -1:
-            if hasattr(m, 'weight'):
-                self.init_weight(m.weight)
-        elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
-            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
-                self.init_weight(m.cluster_weight)
-            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
-                self.init_bias(m.cluster_bias)
-            if hasattr(m, 'out_projs'):
-                for i in range(len(m.out_projs)):
-                    if m.out_projs[i] is not None:
-                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
-        elif classname.find('LayerNorm') != -1:
-            if hasattr(m, 'weight'):
-                nn.init.normal_(m.weight, 1.0, self.config.init_std)
-            if hasattr(m, 'bias') and m.bias is not None:
-                self.init_bias(m.bias)
-        elif classname.find('TransformerLM') != -1:
-            if hasattr(m, 'r_emb'):
-                self.init_weight(m.r_emb)
-            if hasattr(m, 'r_w_bias'):
-                self.init_weight(m.r_w_bias)
-            if hasattr(m, 'r_r_bias'):
-                self.init_weight(m.r_r_bias)
-            if hasattr(m, 'r_bias'):
-                self.init_bias(m.r_bias)
-
-    def set_num_special_tokens(self, num_special_tokens):
-        pass
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
-                        *inputs, **kwargs):
-        """
-        Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
-        Download and cache the pre-trained model file if needed.
-
-        Params:
-            pretrained_model_name: either:
-                - a str with the name of a pre-trained model to load selected in the list of:
-                    . `transfo-xl`
-                - a path or url to a pretrained model archive containing:
-                    . `transfo_xl_config.json` a configuration file for the model
-                    . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
-            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
-        """
-        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
-        else:
-            archive_file = pretrained_model_name
-        # redirect to the cache, if necessary
-        try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except FileNotFoundError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find any file "
-                "associated to this path or url.".format(
-                    pretrained_model_name,
-                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                    archive_file))
-            return None
-        if resolved_archive_file == archive_file:
-            logger.info("loading archive file {}".format(archive_file))
-        else:
-            logger.info("loading archive file {} from cache at {}".format(
-                archive_file, resolved_archive_file))
-        tempdir = None
-        if os.path.isdir(resolved_archive_file):
-            serialization_dir = resolved_archive_file
-        else:
-            # Extract archive to temp dir
-            tempdir = tempfile.mkdtemp()
-            logger.info("extracting archive file {} to temp dir {}".format(
-                resolved_archive_file, tempdir))
-            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
-                archive.extractall(tempdir)
-            serialization_dir = tempdir
-        # Load config
-        config_file = os.path.join(serialization_dir, CONFIG_NAME)
-        config = TransfoXLConfig.from_json_file(config_file)
-        logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        if state_dict is None:
-            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
-            state_dict = torch.load(weights_path)
-
-        old_keys = []
-        new_keys = []
-        for key in state_dict.keys():
-            new_key = None
-            if 'gamma' in key:
-                new_key = key.replace('gamma', 'weight')
-            if 'beta' in key:
-                new_key = key.replace('beta', 'bias')
-            if new_key:
-                old_keys.append(key)
-                new_keys.append(new_key)
-        for old_key, new_key in zip(old_keys, new_keys):
-            state_dict[new_key] = state_dict.pop(old_key)
-
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
-        # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, '_metadata', None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        def load(module, prefix=''):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            module._load_from_state_dict(
-                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + '.')
-        # load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
-        if len(missing_keys) > 0:
-            logger.info("Weights of {} not initialized from pretrained model: {}".format(
-                model.__class__.__name__, missing_keys))
-        if len(unexpected_keys) > 0:
-            logger.info("Weights from pretrained model not used in {}: {}".format(
-                model.__class__.__name__, unexpected_keys))
-        if len(error_msgs) > 0:
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                model.__class__.__name__, "\n\t".join(error_msgs)))
-        if tempdir:
-            # Clean up temp dir
-            shutil.rmtree(tempdir)
-        return model
-
-
-class TransfoXLModel(TransfoXLPreTrainedModel):
-    """ Transformer XL model
-    From "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    by Zihang Dai*, Zhilin Yang*, Yiming Yang, William W. Cohen, Jaime Carbonell,
-    Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)
-
-    Params:
-        config: a TransfoXLConfig class instance with the configuration to build a new model
-
-    Inputs:
-        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
-            were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
-        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
-            with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
-        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
-            You can use it to add a third embedding (the previous two being the word and position embeddings)
-            to each token in the sentence.
-
-    Outputs:
-        `hidden_states`: the encoded-hidden-states at the top of the model
-            as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
-            (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
-
-    Example usage:
-    ```python
-    # Already been converted into BPE token ids
-    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
-
-    config = modeling_transfo_xl.TransfoXLConfig()
-
-    model = modeling_transfo_xl.TransfoXLModel(config)
-    hidden_states = model(input_ids)
-    ```
-    """
-    def __init__(self, config):
-        super(TransfoXLModel, self).__init__(config)
-        self.transformer = MemTransformerLM(**config.to_dict())
-        self.apply(self.init_weights)
-
-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
-        return self.transformer(input_ids, position_ids, token_type_ids)
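With the old wrapper class removed and the archive split into a weights file plus a config file, the intended entry point remains the `from_pretrained` classmethod. A usage sketch based on the docstrings in this file; the forward signature at this intermediate commit still follows the original Transformer-XL code, so the last call is illustrative rather than exact:

```python
import torch
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel

# Either download the released checkpoint by shortcut name...
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')

# ...or build an untrained model from a configuration object.
config = TransfoXLConfig()
scratch_model = TransfoXLModel(config)

# Token ids, already numericalized (see tokenization_transfo_xl.py).
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
# From the docstring example; at this commit forward may still expect
# targets/mems as in the original MemTransformerLM code.
hidden_states = model(input_ids)
```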
pytorch_pretrained_bert/tokenization_transfo_xl.py (view file @ 8831c688)

@@ -444,6 +444,12 @@ class TransfoXLCorpus(object):
         for key, value in corpus_dict.items():
             corpus.__dict__[key] = value
         corpus.vocab = vocab
+        if corpus.train is not None:
+            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
+        if corpus.valid is not None:
+            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
+        if corpus.test is not None:
+            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
         return corpus

     def __init__(self, *args, **kwargs):
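The added lines make the cached corpus usable immediately after deserialization by turning the raw id sequences back into `torch.long` tensors. A small sketch of the same guard-and-convert pattern applied to plain Python lists (the helper name and dict layout are illustrative):

```python
import torch

def to_long_tensors(splits):
    # Convert any non-None split (e.g. a list of token ids) to a torch.long tensor,
    # mirroring the None checks added for corpus.train / corpus.valid / corpus.test.
    return {name: (torch.tensor(ids, dtype=torch.long) if ids is not None else None)
            for name, ids in splits.items()}

splits = {'train': [0, 1, 2, 3], 'valid': [4, 5], 'test': None}
tensors = to_long_tensors(splits)
print(tensors['train'].dtype, tensors['test'])  # torch.int64 None
```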