Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d77dd62f
Commit
d77dd62f
authored
Jan 28, 2019
by
thomwolf
Browse files
directly load from TF checkpoints + code cleanup
parent
9c35c132
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
225 additions
and
178 deletions
+225
-178
pytorch_pretrained_bert/__init__.py
pytorch_pretrained_bert/__init__.py
+6
-0
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
...h_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
+23
-35
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
+17
-12
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
...etrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
+49
-45
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+18
-7
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+97
-76
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+12
-3
pytorch_pretrained_bert/tokenization_openai.py
pytorch_pretrained_bert/tokenization_openai.py
+3
-0
No files found.
pytorch_pretrained_bert/__init__.py
View file @
d77dd62f
...
@@ -2,6 +2,7 @@ __version__ = "0.5.0"
...
@@ -2,6 +2,7 @@ __version__ = "0.5.0"
from
.tokenization
import
BertTokenizer
,
BasicTokenizer
,
WordpieceTokenizer
from
.tokenization
import
BertTokenizer
,
BasicTokenizer
,
WordpieceTokenizer
from
.tokenization_openai
import
OpenAIGPTTokenizer
from
.tokenization_openai
import
OpenAIGPTTokenizer
from
.tokenization_transfo_xl
import
(
TransfoXLTokenizer
,
TransfoXLCorpus
)
from
.tokenization_transfo_xl
import
(
TransfoXLTokenizer
,
TransfoXLCorpus
)
from
.modeling
import
(
BertConfig
,
BertModel
,
BertForPreTraining
,
from
.modeling
import
(
BertConfig
,
BertModel
,
BertForPreTraining
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForSequenceClassification
,
BertForMultipleChoice
,
BertForSequenceClassification
,
BertForMultipleChoice
,
...
@@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
...
@@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
from
.modeling_openai
import
(
OpenAIGPTConfig
,
OpenAIGPTModel
,
from
.modeling_openai
import
(
OpenAIGPTConfig
,
OpenAIGPTModel
,
OpenAIGPTLMHeadModel
,
OpenAIGPTDoubleHeadsModel
)
OpenAIGPTLMHeadModel
,
OpenAIGPTDoubleHeadsModel
)
from
.modeling_transfo_xl
import
(
TransfoXLConfig
,
TransfoXLModel
)
from
.modeling_transfo_xl
import
(
TransfoXLConfig
,
TransfoXLModel
)
from
.optimization
import
BertAdam
from
.optimization
import
BertAdam
from
.optimization_openai
import
OpenAIAdam
from
.optimization_openai
import
OpenAIAdam
from
.convert_openai_checkpoint_to_pytorch
import
load_tf_weights_in_openai_gpt
from
.convert_tf_checkpoint_to_pytorch
import
load_tf_weights_in_bert
from
.convert_transfo_xl_checkpoint_to_pytorch
import
load_tf_weights_in_transfo_xl
from
.file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
from
.file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
View file @
d77dd62f
...
@@ -26,9 +26,29 @@ import numpy as np
...
@@ -26,9 +26,29 @@ import numpy as np
from
.modeling_openai
import
OpenAIGPTConfig
,
OpenAIGPTModel
,
CONFIG_NAME
,
WEIGHTS_NAME
from
.modeling_openai
import
OpenAIGPTConfig
,
OpenAIGPTModel
,
CONFIG_NAME
,
WEIGHTS_NAME
def
convert_openai_checkpoint_to_pytorch
(
openai_checkpoint_folder_path
,
openai_config_file
,
pytorch_dump_folder_path
):
def
convert_openai_checkpoint_to_pytorch
(
openai_checkpoint_folder_path
,
openai_config_file
,
pytorch_dump_folder_path
):
# Load weights from TF model
# Construct model
if
openai_config_file
==
""
:
config
=
OpenAIGPTConfig
()
else
:
config
=
OpenAIGPTConfig
(
openai_config_file
)
model
=
OpenAIGPTModel
(
config
)
# Load weights from numpy
load_tf_weights_in_openai_gpt
(
model
,
openai_checkpoint_folder_path
)
# Save pytorch-model
pytorch_weights_dump_path
=
pytorch_dump_folder_path
+
'/'
+
WEIGHTS_NAME
pytorch_config_dump_path
=
pytorch_dump_folder_path
+
'/'
+
CONFIG_NAME
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_weights_dump_path
))
torch
.
save
(
model
.
state_dict
(),
pytorch_weights_dump_path
)
print
(
"Save configuration file to {}"
.
format
(
pytorch_config_dump_path
))
with
open
(
pytorch_config_dump_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
config
.
to_json_string
())
def
load_tf_weights_in_openai_gpt
(
model
,
openai_checkpoint_folder_path
):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
print
(
"Loading weights..."
)
print
(
"Loading weights..."
)
names
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/parameters_names.json'
,
"r"
,
encoding
=
'utf-8'
))
names
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/parameters_names.json'
,
"r"
,
encoding
=
'utf-8'
))
shapes
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/params_shapes.json'
,
"r"
,
encoding
=
'utf-8'
))
shapes
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/params_shapes.json'
,
"r"
,
encoding
=
'utf-8'
))
...
@@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
...
@@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
init_params
=
[
np
.
load
(
openai_checkpoint_folder_path
+
'/params_{}.npy'
.
format
(
n
))
for
n
in
range
(
10
)]
init_params
=
[
np
.
load
(
openai_checkpoint_folder_path
+
'/params_{}.npy'
.
format
(
n
))
for
n
in
range
(
10
)]
init_params
=
np
.
split
(
np
.
concatenate
(
init_params
,
0
),
offsets
)[:
-
1
]
init_params
=
np
.
split
(
np
.
concatenate
(
init_params
,
0
),
offsets
)[:
-
1
]
init_params
=
[
param
.
reshape
(
shape
)
for
param
,
shape
in
zip
(
init_params
,
shapes
)]
init_params
=
[
param
.
reshape
(
shape
)
for
param
,
shape
in
zip
(
init_params
,
shapes
)]
# if n_ctx > 0:
# init_params[0] = init_params[0][:n_ctx]
# if n_special > 0:
# init_params[0] = np.concatenate(
# [init_params[1],
# (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
# init_params[0]
# ], 0)
# else:
# init_params[0] = np.concatenate(
# [init_params[1],
# init_params[0]
# ], 0)
# del init_params[1]
# if n_transfer == -1:
# n_transfer = 0
# else:
# n_transfer = 1 + n_transfer * 12
init_params
[
0
]
=
np
.
concatenate
([
init_params
[
1
],
init_params
[
0
]],
0
)
init_params
[
0
]
=
np
.
concatenate
([
init_params
[
1
],
init_params
[
0
]],
0
)
del
init_params
[
1
]
del
init_params
[
1
]
init_params
=
[
arr
.
squeeze
()
for
arr
in
init_params
]
init_params
=
[
arr
.
squeeze
()
for
arr
in
init_params
]
# Construct model
if
openai_config_file
==
""
:
config
=
OpenAIGPTConfig
()
else
:
config
=
OpenAIGPTConfig
(
openai_config_file
)
model
=
OpenAIGPTModel
(
config
)
try
:
try
:
assert
model
.
embed
.
weight
.
shape
==
init_params
[
0
].
shape
assert
model
.
embed
.
weight
.
shape
==
init_params
[
0
].
shape
except
AssertionError
as
e
:
except
AssertionError
as
e
:
...
@@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
...
@@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
raise
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
pointer
.
data
=
torch
.
from_numpy
(
array
)
return
model
# Save pytorch-model
pytorch_weights_dump_path
=
pytorch_dump_folder_path
+
'/'
+
WEIGHTS_NAME
pytorch_config_dump_path
=
pytorch_dump_folder_path
+
'/'
+
CONFIG_NAME
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_weights_dump_path
))
torch
.
save
(
model
.
state_dict
(),
pytorch_weights_dump_path
)
print
(
"Save configuration file to {}"
.
format
(
pytorch_config_dump_path
))
with
open
(
pytorch_config_dump_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
config
.
to_json_string
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
View file @
d77dd62f
...
@@ -28,9 +28,23 @@ import numpy as np
...
@@ -28,9 +28,23 @@ import numpy as np
from
.modeling
import
BertConfig
,
BertForPreTraining
from
.modeling
import
BertConfig
,
BertForPreTraining
def
convert_tf_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
def
convert_tf_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
config_path
=
os
.
path
.
abspath
(
bert_config_file
)
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
BertForPreTraining
(
config
)
# Load weights from tf checkpoint
load_tf_weights_in_bert
(
model
,
tf_checkpoint_path
)
# Save pytorch-model
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_dump_path
))
torch
.
save
(
model
.
state_dict
(),
pytorch_dump_path
)
def
load_tf_weights_in_bert
(
model
,
tf_checkpoint_path
):
""" Load tf checkpoints in a pytorch model
"""
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
print
(
"Converting TensorFlow checkpoint from {}
with config at {}"
.
format
(
tf_path
,
config
_path
))
print
(
"Converting TensorFlow checkpoint from {}
"
.
format
(
tf
_path
))
# Load weights from TF model
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
names
=
[]
names
=
[]
...
@@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
...
@@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
names
.
append
(
name
)
names
.
append
(
name
)
arrays
.
append
(
array
)
arrays
.
append
(
array
)
# Initialise PyTorch model
config
=
BertConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
BertForPreTraining
(
config
)
for
name
,
array
in
zip
(
names
,
arrays
):
for
name
,
array
in
zip
(
names
,
arrays
):
name
=
name
.
split
(
'/'
)
name
=
name
.
split
(
'/'
)
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
...
@@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
...
@@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
raise
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
pointer
.
data
=
torch
.
from_numpy
(
array
)
return
model
# Save pytorch-model
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_dump_path
))
torch
.
save
(
model
.
state_dict
(),
pytorch_dump_path
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
View file @
d77dd62f
...
@@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config):
...
@@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config):
'transformer/r_w_bias'
:
r_w_list
})
'transformer/r_w_bias'
:
r_w_list
})
return
tf_to_pt_map
return
tf_to_pt_map
def
convert_transfo_xl_checkpoint_to_pytorch
(
tf_checkpoint_path
,
def
convert_transfo_xl_checkpoint_to_pytorch
(
tf_checkpoint_path
,
transfo_xl_config_file
,
transfo_xl_config_file
,
pytorch_dump_folder_path
,
pytorch_dump_folder_path
,
...
@@ -140,6 +139,20 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
...
@@ -140,6 +139,20 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
TransfoXLModel
(
config
)
model
=
TransfoXLModel
(
config
)
model
=
load_tf_weights_in_transfo_xl
(
model
,
config
,
tf_path
)
# Save pytorch-model
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
pytorch_config_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
CONFIG_NAME
)
print
(
"Save PyTorch model to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_weights_dump_path
)))
torch
.
save
(
model
.
state_dict
(),
pytorch_weights_dump_path
)
print
(
"Save configuration file to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_config_dump_path
)))
with
open
(
pytorch_config_dump_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
config
.
to_json_string
())
def
load_tf_weights_in_transfo_xl
(
model
,
config
,
tf_path
):
""" Load tf checkpoints in a pytorch model
"""
# Build TF to PyTorch weights loading map
# Build TF to PyTorch weights loading map
tf_to_pt_map
=
build_tf_to_pytorch_map
(
model
,
config
)
tf_to_pt_map
=
build_tf_to_pytorch_map
(
model
,
config
)
...
@@ -183,16 +196,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
...
@@ -183,16 +196,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
tf_weights
.
pop
(
name
+
'/Adam_1'
,
None
)
tf_weights
.
pop
(
name
+
'/Adam_1'
,
None
)
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
return
model
# Save pytorch-model
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
pytorch_config_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
CONFIG_NAME
)
print
(
"Save PyTorch model to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_weights_dump_path
)))
torch
.
save
(
model
.
state_dict
(),
pytorch_weights_dump_path
)
print
(
"Save configuration file to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_config_dump_path
)))
with
open
(
pytorch_config_dump_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
config
.
to_json_string
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
pytorch_pretrained_bert/modeling.py
View file @
d77dd62f
...
@@ -33,6 +33,7 @@ from torch import nn
...
@@ -33,6 +33,7 @@ from torch import nn
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
CrossEntropyLoss
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.convert_tf_checkpoint_to_pytorch
import
load_tf_weights_in_bert
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
}
}
CONFIG_NAME
=
'bert_config.json'
CONFIG_NAME
=
'bert_config.json'
WEIGHTS_NAME
=
'pytorch_model.bin'
WEIGHTS_NAME
=
'pytorch_model.bin'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
gelu
(
x
):
def
gelu
(
x
):
"""Implementation of the gelu activation function.
"""Implementation of the gelu activation function.
...
@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module):
...
@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module):
module
.
bias
.
data
.
zero_
()
module
.
bias
.
data
.
zero_
()
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name
,
state_dict
=
None
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name
,
state_dict
=
None
,
cache_dir
=
None
,
from_tf
=
False
,
*
inputs
,
**
kwargs
):
"""
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Download and cache the pre-trained model file if needed.
...
@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module):
...
@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module):
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
*inputs, **kwargs: additional input for the specific Bert class
...
@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module):
...
@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module):
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
archive_file
,
resolved_archive_file
))
tempdir
=
None
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
):
if
os
.
path
.
isdir
(
resolved_archive_file
)
or
from_tf
:
serialization_dir
=
resolved_archive_file
serialization_dir
=
resolved_archive_file
else
:
else
:
# Extract archive to temp dir
# Extract archive to temp dir
...
@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module):
...
@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module):
logger
.
info
(
"Model config {}"
.
format
(
config
))
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
:
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
weights_path
)
state_dict
=
torch
.
load
(
weights_path
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
weights_path
=
os
.
path
.
join
(
serialization_dir
,
TF_WEIGHTS_NAME
)
return
load_tf_weights_in_bert
(
model
,
weights_path
)
# Load from a PyTorch state_dict
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
for
key
in
state_dict
.
keys
():
...
@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module):
...
@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module):
if
len
(
error_msgs
)
>
0
:
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
return
model
return
model
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
d77dd62f
...
@@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter
...
@@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter
from
.modeling
import
BertLayerNorm
as
LayerNorm
from
.modeling
import
BertLayerNorm
as
LayerNorm
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.convert_openai_checkpoint_to_pytorch
import
load_tf_weights_in_openai_gpt
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"
}
'openai-gpt'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"
,
CONFIG_NAME
=
"openai_gpt_config.json"
}
WEIGHTS_NAME
=
"pytorch_model.bin"
CONFIG_NAME
=
'openai_gpt_config.json'
WEIGHTS_NAME
=
'pytorch_model.bin'
def
gelu
(
x
):
def
gelu
(
x
):
return
0.5
*
x
*
(
1
+
torch
.
tanh
(
math
.
sqrt
(
2
/
math
.
pi
)
*
(
x
+
0.044715
*
torch
.
pow
(
x
,
3
))))
return
0.5
*
x
*
(
1
+
torch
.
tanh
(
math
.
sqrt
(
2
/
math
.
pi
)
*
(
x
+
0.044715
*
torch
.
pow
(
x
,
3
))))
...
@@ -49,16 +49,15 @@ def swish(x):
...
@@ -49,16 +49,15 @@ def swish(x):
return
x
*
torch
.
sigmoid
(
x
)
return
x
*
torch
.
sigmoid
(
x
)
ACT_FNS
=
{
ACT_FNS
=
{
"relu"
:
nn
.
ReLU
,
"swish"
:
swish
,
"gelu"
:
gelu
}
'relu'
:
nn
.
ReLU
,
'swish'
:
swish
,
'gelu'
:
gelu
}
class
OpenAIGPTConfig
(
object
):
class
OpenAIGPTConfig
(
object
):
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
vocab_size_or_config_json_file
=
40478
,
vocab_size_or_config_json_file
=
40478
,
n_special
=
0
,
n_special
=
0
,
n_ctx
=
512
,
n_ctx
=
512
,
...
@@ -69,7 +68,8 @@ class OpenAIGPTConfig(object):
...
@@ -69,7 +68,8 @@ class OpenAIGPTConfig(object):
resid_pdrop
=
0.1
,
resid_pdrop
=
0.1
,
embd_pdrop
=
0.1
,
embd_pdrop
=
0.1
,
attn_pdrop
=
0.1
,
attn_pdrop
=
0.1
,
initializer_range
=
0.02
):
initializer_range
=
0.02
,
):
"""Constructs OpenAIGPTConfig.
"""Constructs OpenAIGPTConfig.
Args:
Args:
...
@@ -91,7 +91,7 @@ class OpenAIGPTConfig(object):
...
@@ -91,7 +91,7 @@ class OpenAIGPTConfig(object):
initializing all weight matrices.
initializing all weight matrices.
"""
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
):
if
isinstance
(
vocab_size_or_config_json_file
,
str
):
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
'
utf-8
'
)
as
reader
:
with
open
(
vocab_size_or_config_json_file
,
"r"
,
encoding
=
"
utf-8
"
)
as
reader
:
json_config
=
json
.
loads
(
reader
.
read
())
json_config
=
json
.
loads
(
reader
.
read
())
for
key
,
value
in
json_config
.
items
():
for
key
,
value
in
json_config
.
items
():
self
.
__dict__
[
key
]
=
value
self
.
__dict__
[
key
]
=
value
...
@@ -108,8 +108,10 @@ class OpenAIGPTConfig(object):
...
@@ -108,8 +108,10 @@ class OpenAIGPTConfig(object):
self
.
attn_pdrop
=
attn_pdrop
self
.
attn_pdrop
=
attn_pdrop
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
else
:
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
raise
ValueError
(
"or the path to a pretrained model config file (str)"
)
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@
property
@
property
def
total_num_embeddings
(
self
):
def
total_num_embeddings
(
self
):
...
@@ -126,7 +128,7 @@ class OpenAIGPTConfig(object):
...
@@ -126,7 +128,7 @@ class OpenAIGPTConfig(object):
@
classmethod
@
classmethod
def
from_json_file
(
cls
,
json_file
):
def
from_json_file
(
cls
,
json_file
):
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
with
open
(
json_file
,
"r"
,
encoding
=
'
utf-8
'
)
as
reader
:
with
open
(
json_file
,
"r"
,
encoding
=
"
utf-8
"
)
as
reader
:
text
=
reader
.
read
()
text
=
reader
.
read
()
return
cls
.
from_dict
(
json
.
loads
(
text
))
return
cls
.
from_dict
(
json
.
loads
(
text
))
...
@@ -142,6 +144,7 @@ class OpenAIGPTConfig(object):
...
@@ -142,6 +144,7 @@ class OpenAIGPTConfig(object):
"""Serializes this instance to a JSON string."""
"""Serializes this instance to a JSON string."""
return
json
.
dumps
(
self
.
to_dict
(),
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
return
json
.
dumps
(
self
.
to_dict
(),
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
class
Conv1D
(
nn
.
Module
):
class
Conv1D
(
nn
.
Module
):
def
__init__
(
self
,
nf
,
rf
,
nx
):
def
__init__
(
self
,
nf
,
rf
,
nx
):
super
(
Conv1D
,
self
).
__init__
()
super
(
Conv1D
,
self
).
__init__
()
...
@@ -171,7 +174,7 @@ class Attention(nn.Module):
...
@@ -171,7 +174,7 @@ class Attention(nn.Module):
n_state
=
nx
# in Attention: n_state=768 (nx=n_embd)
n_state
=
nx
# in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert
n_state
%
config
.
n_head
==
0
assert
n_state
%
config
.
n_head
==
0
self
.
register_buffer
(
'b'
,
torch
.
tril
(
torch
.
ones
(
n_ctx
,
n_ctx
)).
view
(
1
,
1
,
n_ctx
,
n_ctx
))
self
.
register_buffer
(
"b"
,
torch
.
tril
(
torch
.
ones
(
n_ctx
,
n_ctx
)).
view
(
1
,
1
,
n_ctx
,
n_ctx
))
self
.
n_head
=
config
.
n_head
self
.
n_head
=
config
.
n_head
self
.
split_size
=
n_state
self
.
split_size
=
n_state
self
.
scale
=
scale
self
.
scale
=
scale
...
@@ -186,7 +189,7 @@ class Attention(nn.Module):
...
@@ -186,7 +189,7 @@ class Attention(nn.Module):
w
=
w
/
math
.
sqrt
(
v
.
size
(
-
1
))
w
=
w
/
math
.
sqrt
(
v
.
size
(
-
1
))
# w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
# w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
# XD: self.b may be larger than w, so we need to crop it
# XD: self.b may be larger than w, so we need to crop it
b
=
self
.
b
[:,
:,
:
w
.
size
(
-
2
),
:
w
.
size
(
-
1
)]
b
=
self
.
b
[:,
:,
:
w
.
size
(
-
2
),
:
w
.
size
(
-
1
)]
w
=
w
*
b
+
-
1e9
*
(
1
-
b
)
w
=
w
*
b
+
-
1e9
*
(
1
-
b
)
w
=
nn
.
Softmax
(
dim
=-
1
)(
w
)
w
=
nn
.
Softmax
(
dim
=-
1
)(
w
)
...
@@ -281,14 +284,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
...
@@ -281,14 +284,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
self
.
dropout
=
nn
.
Dropout2d
(
config
.
resid_pdrop
)
# To reproduce the noise_shape parameter of TF implementation
self
.
dropout
=
nn
.
Dropout2d
(
config
.
resid_pdrop
)
# To reproduce the noise_shape parameter of TF implementation
self
.
linear
=
nn
.
Linear
(
config
.
n_embd
,
1
)
self
.
linear
=
nn
.
Linear
(
config
.
n_embd
,
1
)
nn
.
init
.
normal_
(
self
.
linear
.
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
linear
.
weight
,
std
=
0.02
)
nn
.
init
.
normal_
(
self
.
linear
.
bias
,
0
)
nn
.
init
.
normal_
(
self
.
linear
.
bias
,
0
)
def
forward
(
self
,
hidden_states
,
m
ultiple_choice
_token_mask
):
def
forward
(
self
,
hidden_states
,
m
c
_token_mask
):
# Classification logits
# Classification logits
# hidden_states = hidden_states.view(-1, self.n_embd)
# hidden_states = hidden_states.view(-1, self.n_embd)
# multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states)
# mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states)
multiple_choice_h
=
hidden_states
*
multiple_choice_token_mask
.
unsqueeze
(
-
1
)
mc_token_mask
=
mc_token_mask
.
float
()
multiple_choice_h
=
hidden_states
*
mc_token_mask
.
unsqueeze
(
-
1
)
multiple_choice_h
=
multiple_choice_h
.
sum
(
dim
=-
2
)
multiple_choice_h
=
multiple_choice_h
.
sum
(
dim
=-
2
)
# flat = x[..., 0].contiguous().view(-1)
# flat = x[..., 0].contiguous().view(-1)
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
...
@@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
a simple interface for downloading and loading pretrained models.
"""
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
OpenAIGPTPreTrainedModel
,
self
).
__init__
()
super
(
OpenAIGPTPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
OpenAIGPTConfig
):
if
not
isinstance
(
config
,
OpenAIGPTConfig
):
...
@@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"To create a model from a pretrained model use "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
)
)
self
.
config
=
config
self
.
config
=
config
def
init_weights
(
self
,
module
):
def
init_weights
(
self
,
module
):
...
@@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module):
pass
pass
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name
,
num_special_tokens
=
0
,
state_dict
=
None
,
cache_dir
=
None
,
def
from_pretrained
(
*
inputs
,
**
kwargs
):
cls
,
pretrained_model_name
,
num_special_tokens
=
None
,
state_dict
=
None
,
cache_dir
=
None
,
from_tf
=
False
,
*
inputs
,
**
kwargs
):
"""
"""
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Download and cache the pre-trained model file if needed.
...
@@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model
. `openai_gpt_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific OpenAI GPT class
*inputs, **kwargs: additional input for the specific OpenAI GPT class
...
@@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"Model name '{}' was not found in model name list ({}). "
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
"associated to this path or url."
.
format
(
pretrained_model_name
,
pretrained_model_name
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
)
archive_file
)
)
)
return
None
return
None
if
resolved_archive_file
==
archive_file
:
if
resolved_archive_file
==
archive_file
:
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading archive file {}"
.
format
(
archive_file
))
else
:
else
:
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
logger
.
info
(
"loading archive file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
archive_file
,
resolved_archive_file
))
tempdir
=
None
tempdir
=
None
if
os
.
path
.
isdir
(
resolved_archive_file
):
if
os
.
path
.
isdir
(
resolved_archive_file
):
serialization_dir
=
resolved_archive_file
serialization_dir
=
resolved_archive_file
else
:
else
:
# Extract archive to temp dir
# Extract archive to temp dir
tempdir
=
tempfile
.
mkdtemp
()
tempdir
=
tempfile
.
mkdtemp
()
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
logger
.
info
(
"extracting archive file {} to temp dir {}"
.
format
(
resolved_archive_file
,
tempdir
))
resolved_archive_file
,
tempdir
))
with
tarfile
.
open
(
resolved_archive_file
,
"r:gz"
)
as
archive
:
with
tarfile
.
open
(
resolved_archive_file
,
'r:gz'
)
as
archive
:
archive
.
extractall
(
tempdir
)
archive
.
extractall
(
tempdir
)
serialization_dir
=
tempdir
serialization_dir
=
tempdir
# Load config
# Load config
...
@@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module):
logger
.
info
(
"Model config {}"
.
format
(
config
))
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
:
if
state_dict
is
None
and
not
from_tf
:
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
weights_path
=
os
.
path
.
join
(
serialization_dir
,
WEIGHTS_NAME
)
state_dict
=
torch
.
load
(
weights_path
)
state_dict
=
torch
.
load
(
weights_path
,
map_location
=
'cpu'
if
not
torch
.
cuda
.
is_available
()
else
None
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return
load_tf_weights_in_openai_gpt
(
model
,
serialization_dir
)
old_keys
=
[]
old_keys
=
[]
new_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
for
key
in
state_dict
.
keys
():
new_key
=
None
new_key
=
None
if
'
gamma
'
in
key
:
if
"
gamma
"
in
key
:
new_key
=
key
.
replace
(
'
gamma
'
,
'
weight
'
)
new_key
=
key
.
replace
(
"
gamma
"
,
"
weight
"
)
if
'
beta
'
in
key
:
if
"
beta
"
in
key
:
new_key
=
key
.
replace
(
'
beta
'
,
'
bias
'
)
new_key
=
key
.
replace
(
"
beta
"
,
"
bias
"
)
if
new_key
:
if
new_key
:
old_keys
.
append
(
key
)
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
new_keys
.
append
(
new_key
)
...
@@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module):
...
@@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module):
unexpected_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'
_metadata
'
,
None
)
metadata
=
getattr
(
state_dict
,
"
_metadata
"
,
None
)
state_dict
=
state_dict
.
copy
()
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
''
):
def
load
(
module
,
prefix
=
""
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
load
(
child
,
prefix
+
name
+
"."
)
load
(
model
.
transformer
if
hasattr
(
model
,
'transformer'
)
else
model
,
prefix
=
''
)
load
(
model
.
transformer
if
hasattr
(
model
,
"transformer"
)
else
model
,
prefix
=
""
)
if
len
(
missing_keys
)
>
0
:
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
logger
.
info
(
model
.
__class__
.
__name__
,
missing_keys
))
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
)
)
if
len
(
unexpected_keys
)
>
0
:
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
logger
.
info
(
model
.
__class__
.
__name__
,
unexpected_keys
))
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
)
)
if
len
(
error_msgs
)
>
0
:
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
raise
RuntimeError
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
# Add additional embeddings for special tokens if needed
# Add additional embeddings for special tokens if needed
if
num_special_tokens
!=
config
.
n_special
:
if
num_special_tokens
is
not
None
and
num_special_tokens
!=
config
.
n_special
:
model
.
set_num_special_tokens
(
num_special_tokens
)
model
.
set_num_special_tokens
(
num_special_tokens
)
if
tempdir
:
# Clean up temp dir
shutil
.
rmtree
(
tempdir
)
return
model
return
model
...
@@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states = model(input_ids)
hidden_states = model(input_ids)
```
```
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTModel
,
self
).
__init__
(
config
)
super
(
OpenAIGPTModel
,
self
).
__init__
(
config
)
total_embeddings_size
=
config
.
vocab_size
+
config
.
n_special
+
config
.
n_ctx
total_embeddings_size
=
config
.
vocab_size
+
config
.
n_special
+
config
.
n_ctx
...
@@ -516,8 +534,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -516,8 +534,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Initialize all new embeddings (in particular the special tokens)
# Initialize all new embeddings (in particular the special tokens)
self
.
init_weights
(
self
.
embed
)
self
.
init_weights
(
self
.
embed
)
# Copy word and positional embeddings from the previous weights
# Copy word and positional embeddings from the previous weights
self
.
embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
self
.
embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
self
.
embed
.
weight
.
data
[
-
self
.
config
.
n_ctx
:,
:]
=
old_embed
.
weight
.
data
[
-
self
.
config
.
n_ctx
:,
:]
self
.
embed
.
weight
.
data
[
-
self
.
config
.
n_ctx
:,
:]
=
old_embed
.
weight
.
data
[
-
self
.
config
.
n_ctx
:,
:]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
):
if
position_ids
is
None
:
if
position_ids
is
None
:
...
@@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states
=
block
(
hidden_states
)
hidden_states
=
block
(
hidden_states
)
return
hidden_states
.
view
(
*
input_shape
,
hidden_states
.
size
(
-
1
))
return
hidden_states
.
view
(
*
input_shape
,
hidden_states
.
size
(
-
1
))
class
OpenAIGPTLMHeadModel
(
OpenAIGPTPreTrainedModel
):
class
OpenAIGPTLMHeadModel
(
OpenAIGPTPreTrainedModel
):
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
...
@@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
lm_logits = model(input_ids)
lm_logits = model(input_ids)
```
```
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTLMHeadModel
,
self
).
__init__
(
config
)
super
(
OpenAIGPTLMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
...
@@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
return
loss
return
loss
return
lm_logits
return
lm_logits
class
OpenAIGPTDoubleHeadsModel
(
OpenAIGPTPreTrainedModel
):
class
OpenAIGPTDoubleHeadsModel
(
OpenAIGPTPreTrainedModel
):
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
...
@@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
Inputs:
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word BPE token indices selected in the range [0, config.vocab_size[
with the word BPE token indices selected in the range [0, config.vocab_size[
`m
ultiple_choice
_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
`m
c
_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special,
with the position indices (selected in the range [config.vocab_size + config.n_special,
...
@@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
```python
```python
# Already been converted into BPE token ids
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
m
ultiple_choice
_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
m
c
_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling_openai.OpenAIGPTConfig()
config = modeling_openai.OpenAIGPTConfig()
model = modeling_openai.OpenAIGPTLMHeadModel(config)
model = modeling_openai.OpenAIGPTLMHeadModel(config)
lm_logits, multiple_choice_logits = model(input_ids, m
ultiple_choice
_token_mask)
lm_logits, multiple_choice_logits = model(input_ids, m
c
_token_mask)
```
```
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTDoubleHeadsModel
,
self
).
__init__
(
config
)
super
(
OpenAIGPTDoubleHeadsModel
,
self
).
__init__
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
...
@@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
embed
.
weight
)
def
forward
(
self
,
input_ids
,
multiple_choice_token_mask
,
position_ids
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
mc_token_mask
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
):
lm_labels
=
None
,
multiple_choice_labels
=
None
):
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
m
ultiple_choice
_logits
=
self
.
multiple_choice_head
(
hidden_states
,
m
ultiple_choice
_token_mask
)
m
c
_logits
=
self
.
multiple_choice_head
(
hidden_states
,
m
c
_token_mask
)
losses
=
[]
losses
=
[]
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
losses
.
append
(
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
)))
if
m
ultiple_choice
_labels
is
not
None
:
if
m
c
_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
losses
.
append
(
loss_fct
(
m
ultiple_choice_logits
,
multiple_choice
_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
m
c_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc
_labels
.
view
(
-
1
)))
if
losses
:
if
losses
:
return
losses
return
losses
return
lm_logits
,
m
ultiple_choice
_logits
return
lm_logits
,
m
c
_logits
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
d77dd62f
...
@@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter
...
@@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter
from
.modeling
import
BertLayerNorm
as
LayerNorm
from
.modeling
import
BertLayerNorm
as
LayerNorm
from
.modeling_transfo_xl_utilities
import
ProjectedAdaptiveLogSoftmax
,
sample_logits
from
.modeling_transfo_xl_utilities
import
ProjectedAdaptiveLogSoftmax
,
sample_logits
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.convert_transfo_xl_checkpoint_to_pytorch
import
load_tf_weights_in_transfo_xl
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
...
@@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
}
CONFIG_NAME
=
'transfo_xl_config.json'
CONFIG_NAME
=
'transfo_xl_config.json'
WEIGHTS_NAME
=
'pytorch_model.bin'
WEIGHTS_NAME
=
'pytorch_model.bin'
TF_WEIGHTS_NAME
=
'model.ckpt'
class
TransfoXLConfig
(
object
):
class
TransfoXLConfig
(
object
):
"""Configuration class to store the configuration of a `TransfoXLModel`.
"""Configuration class to store the configuration of a `TransfoXLModel`.
...
@@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module):
...
@@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
state_dict
=
None
,
cache_dir
=
None
,
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
state_dict
=
None
,
cache_dir
=
None
,
*
inputs
,
**
kwargs
):
from_tf
=
False
,
*
inputs
,
**
kwargs
):
"""
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Download and cache the pre-trained model file if needed.
...
@@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module):
...
@@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module):
- a path or url to a pretrained model archive containing:
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
*inputs, **kwargs: additional input for the specific Bert class
...
@@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module):
...
@@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module):
logger
.
info
(
"Model config {}"
.
format
(
config
))
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
:
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
)
state_dict
=
torch
.
load
(
resolved_archive_file
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
weights_path
=
os
.
path
.
join
(
serialization_dir
,
TF_WEIGHTS_NAME
)
return
load_tf_weights_in_transfo_xl
(
model
,
weights_path
)
missing_keys
=
[]
missing_keys
=
[]
unexpected_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
error_msgs
=
[]
...
...
pytorch_pretrained_bert/tokenization_openai.py
View file @
d77dd62f
...
@@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object):
...
@@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object):
else
:
else
:
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
def
__len__
(
self
):
return
len
(
self
.
encoder
)
+
len
(
self
.
special_tokens
)
def
set_special_tokens
(
self
,
special_tokens
):
def
set_special_tokens
(
self
,
special_tokens
):
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment