transformers · commit ffd62382 ("adding gpt2")
Authored Feb 17, 2019 by thomwolf · parent 3a2f97db

Showing 7 changed files with 1246 additions and 4 deletions (+1246 −4):
  pytorch_pretrained_bert/__init__.py                             +3    −0
  pytorch_pretrained_bert/__main__.py                             +22   −4
  pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py   +72   −0
  pytorch_pretrained_bert/modeling_gpt2.py                        +681  −0
  pytorch_pretrained_bert/tokenization_gpt2.py                    +199  −0
  tests/modeling_gpt2_test.py                                     +213  −0
  tests/tokenization_gpt2_test.py                                 +56   −0

pytorch_pretrained_bert/__init__.py  (+3 −0)

@@ -13,6 +13,9 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               load_tf_weights_in_openai_gpt)
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
                                   load_tf_weights_in_transfo_xl)
+from .modeling_gpt2 import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel,
+                            load_tf_weights_in_gpt2)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
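
With this hunk the GPT-2 classes become importable from the package root alongside the existing BERT, GPT and Transformer-XL exports. A minimal sanity check (illustrative, not part of the diff; assumes the package at this commit is installed):

```python
# The names below are exactly the ones added to __init__.py above.
from pytorch_pretrained_bert import (GPT2Config, GPT2Model, GPT2LMHeadModel,
                                     GPT2DoubleHeadsModel, load_tf_weights_in_gpt2)

config = GPT2Config()   # defaults defined in modeling_gpt2.py
print(config)           # __repr__ dumps the JSON form of the config
```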

pytorch_pretrained_bert/__main__.py  (+22 −4)

@@ -4,13 +4,15 @@ def main():
     if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
         "convert_tf_checkpoint_to_pytorch",
         "convert_openai_checkpoint",
-        "convert_transfo_xl_checkpoint"
+        "convert_transfo_xl_checkpoint",
+        "convert_gpt2_checkpoint",
     ]:
         print(
         "Should be used as one of: \n"
         ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-        ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
-        ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+        ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
+        ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
+        ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
     else:
         if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
             try:

@@ -40,7 +42,7 @@ def main():
             convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
                                                  OPENAI_GPT_CONFIG,
                                                  PYTORCH_DUMP_OUTPUT)
-        else:
+        elif sys.argv[1] == "convert_transfo_xl_checkpoint":
             try:
                 from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
             except ImportError:

@@ -61,5 +63,21 @@ def main():
             else:
                 TF_CONFIG = ""
             convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
+        else:
+            try:
+                from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
+            except ImportError:
+                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
+                      "In that case, it requires TensorFlow to be installed. Please see "
+                      "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+
+            TF_CHECKPOINT = sys.argv[2]
+            PYTORCH_DUMP_OUTPUT = sys.argv[3]
+            if len(sys.argv) == 5:
+                TF_CONFIG = sys.argv[4]
+            else:
+                TF_CONFIG = ""
+            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
 if __name__ == '__main__':
     main()

pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py  (new file, mode 100755, +72)

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT-2 checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
from io import open

import torch

from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
                                                   GPT2Config,
                                                   GPT2Model,
                                                   load_tf_weights_in_gpt2)


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    # Construct model
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--gpt2_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    parser.add_argument("--gpt2_config_file",
                        default="",
                        type=str,
                        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
                             "This specifies the model architecture.")
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
                                       args.gpt2_config_file,
                                       args.pytorch_dump_folder_path)
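
The converter can be driven through the package entry point added in __main__.py above (`pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`) or by calling the function directly. A sketch of the direct call, with hypothetical paths and TensorFlow installed:

```python
from pytorch_pretrained_bert.convert_gpt2_checkpoint_to_pytorch import (
    convert_gpt2_checkpoint_to_pytorch)

# Hypothetical paths -- substitute a real TF GPT-2 checkpoint and an output folder.
convert_gpt2_checkpoint_to_pytorch(
    gpt2_checkpoint_path="/path/to/gpt2/model.ckpt",   # TensorFlow checkpoint
    gpt2_config_file="",                               # "" -> use the default GPT2Config
    pytorch_dump_folder_path="/path/to/output")        # writes pytorch_model.bin + config.json
```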

pytorch_pretrained_bert/modeling_gpt2.py  (new file, mode 100644, +681)

# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

import collections
import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .file_utils import cached_path
from .modeling import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"


def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'w' or l[0] == 'g':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'b':
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
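
The `gelu` above is the tanh approximation of the Gaussian Error Linear Unit used by the original TensorFlow GPT-2 code. As a quick illustration (a sketch, not part of the diff), it stays very close to the exact erf-based GELU:

```python
import math
import torch

def gelu_tanh(x):
    # same expression as gelu() in modeling_gpt2.py above
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

def gelu_exact(x):
    # exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1 + torch.erf(x / math.sqrt(2.0)))

x = torch.linspace(-5, 5, steps=101)
print(torch.max(torch.abs(gelu_tanh(x) - gelu_exact(x))))  # small (well below 1e-2)
```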

class GPT2Config(object):
    """Configuration class to store the configuration of a `GPT2Model`.
    """

    def __init__(
        self,
        vocab_size_or_config_json_file=40478,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    ):
        """Constructs GPT2Config.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
            n_positions: Number of positional embeddings.
            n_ctx: Size of the causal mask (usually same as n_positions).
            n_embd: Dimensionality of the embeddings and hidden states.
            n_layer: Number of hidden layers in the Transformer encoder.
            n_head: Number of attention heads for each attention layer in
                the Transformer encoder.
            layer_norm_epsilon: epsilon to use in the layer norm layers
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `GPT2Config` from a Python dictionary of parameters."""
        config = GPT2Config(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `GPT2Config` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x
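
Despite its name, `Conv1D` is simply a linear projection applied to the last dimension, with the weight stored as `(nx, nf)` the way the TF GPT-2 code stores it. A shape check (illustrative):

```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import Conv1D

conv = Conv1D(nf=9, nx=3)      # project the last dimension from 3 to 9
x = torch.randn(2, 5, 3)       # (batch, seq, nx)
y = conv(x)
print(y.shape)                 # torch.Size([2, 5, 9])
```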

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
        # XD: self.b may be larger than w, so we need to crop it
        b = self.bias[:, :, :w.size(-2), :w.size(-1)]
        w = w * b + -1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        present = key, value
        if past is not None:
            past_key, past_value = past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present
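
The causal mask is built once as a lower-triangular `(1, 1, n_ctx, n_ctx)` buffer and cropped to the current score size inside `_attn`, where masked positions receive a large negative value before the softmax. The masking step in isolation (illustrative, not part of the diff):

```python
import torch

n_ctx = 5
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)  # same buffer as Attention.__init__

seq = 3
w = torch.zeros(1, 1, seq, seq)            # fake attention scores
b = bias[:, :, :seq, :seq]                 # crop, as in Attention._attn
w = w * b + -1e10 * (1 - b)                # future positions get -1e10
print(torch.softmax(w, dim=-1)[0, 0])      # row i attends uniformly over positions 0..i, ~0 elsewhere
```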

class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, past=None):
        a, present = self.attn(self.ln_1(x), past=past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present

class GPT2LMHead(nn.Module):
    """ Language Model Head for the transformer """

    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        # Truncated Language modeling logits (we remove the last token)
        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
        lm_logits = self.decoder(hidden_state)
        return lm_logits
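
`GPT2LMHead` ties the output projection to the token embedding matrix by assigning the same `Parameter` object to `decoder.weight`; the `set_tied()` methods on the full models re-apply this after weights are reloaded. A quick check of the sharing (illustrative, assuming the classes above are importable):

```python
import torch
import torch.nn as nn
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2LMHead

config = GPT2Config()
wte = nn.Embedding(config.vocab_size, config.n_embd)
lm_head = GPT2LMHead(wte.weight, config)

assert lm_head.decoder.weight is wte.weight          # same Parameter object, no copy
logits = lm_head(torch.randn(2, 4, config.n_embd))   # (batch, seq, vocab_size)
print(logits.shape)
```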

class GPT2MultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer """

    def __init__(self, config):
        super(GPT2MultipleChoiceHead, self).__init__()
        self.n_embd = config.n_embd
        self.linear = nn.Linear(config.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, hidden_states, mc_token_ids):
        # Classification logits
        # hidden_state (bsz, num_choices, seq_length, hidden_size)
        # mc_token_ids (bsz, num_choices)
        mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
        # (bsz, num_choices, 1, hidden_size)
        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
        # (bsz, num_choices, hidden_size)
        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
        # (bsz, num_choices)
        return multiple_choice_logits
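
The multiple-choice head picks one hidden state per choice (the position given by `mc_token_ids`) with `gather`, then projects it to a single score. The indexing trick on its own (illustrative):

```python
import torch

bsz, num_choices, seq_len, hidden = 2, 3, 5, 4
hidden_states = torch.randn(bsz, num_choices, seq_len, hidden)
mc_token_ids = torch.randint(0, seq_len, (bsz, num_choices))

# same expansion as GPT2MultipleChoiceHead.forward
idx = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden)
picked = hidden_states.gather(2, idx).squeeze(2)      # (bsz, num_choices, hidden)
print(picked.shape)
```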

class GPT2PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(GPT2PreTrainedModel, self).__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    def set_tied(self):
        pass

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
                        from_tf=False, *inputs, **kwargs):
        """
        Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `gpt2`
                - a path or url to a pretrained model archive containing:
                    . `gpt2_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . a TensorFlow checkpoint with trained weights
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
            *inputs, **kwargs: additional inputs for the specific GPT-2 class
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path, archive_file, config_file
                )
            )
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = GPT2Config.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
        if from_tf:
            # Directly load from a TensorFlow checkpoint (stored as NumPy array)
            return load_tf_weights_in_gpt2(model, resolved_archive_file)

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if key.endswith(".g"):
                new_key = key[:-2] + ".weight"
            elif key.endswith(".b"):
                new_key = key[:-2] + ".bias"
            elif key.endswith(".w"):
                new_key = key[:-2] + ".weight"
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        start_model = model
        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
            start_model = model.transformer
        load(start_model, prefix="")

        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)))

        # Make sure we are still sharing the output and input embeddings after loading weights
        model.set_tied()
        return model
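
`from_pretrained` resolves either the shortcut name `'gpt2'` (through the archive maps at the top of this file, downloading and caching the weights and config) or a local directory containing `pytorch_model.bin` and `config.json`, such as one produced by the conversion script above. A usage sketch (illustrative; downloads the pretrained files on first use):

```python
from pytorch_pretrained_bert import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')   # or a local folder with pytorch_model.bin + config.json
model.eval()
```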

class GPT2Model(GPT2PreTrainedModel):
    """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").

    Params:
        config: a GPT2Config class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.

    Outputs:
        `hidden_states`: the encoded-hidden-states at the top of the model
            as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
            (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2Model(config)
    hidden_states = model(input_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd)

        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        past_length = 0 if past is None else past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block in self.h:
            hidden_states, present = block(hidden_states)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents
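
Following the docstring example, the base model returns the final hidden states plus a list of per-layer key/value `presents` (one entry per transformer block). A shape sketch (illustrative, not part of the diff):

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2Model

config = GPT2Config()
model = GPT2Model(config)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])   # (batch=2, seq=3) of BPE ids
hidden_states, presents = model(input_ids)
print(hidden_states.shape)   # torch.Size([2, 3, 768])
print(len(presents))         # config.n_layer == 12
```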

class GPT2LMHeadModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").

    Params:
        config: a GPT2Config class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., vocab_size]

    Outputs:
        if `lm_labels` is not `None`:
            Outputs the language modeling loss.
        else:
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
                (or more generally [d_1, ..., d_n, config.vocab_size] where d_1 ... d_n are the dimensions of input_ids)

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2LMHeadModel(config)
    lm_logits = model(input_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self.init_weights)

    def set_tied(self):
        """ Make sure we are sharing the embeddings
        """
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        if lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
            return loss
        return lm_logits, presents
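
With `lm_labels` supplied, the LM-head model returns a scalar cross-entropy loss (labels of -1 are ignored); without labels it returns the logits and the per-layer `presents`. An illustrative sketch:

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2LMHeadModel

config = GPT2Config()
model = GPT2LMHeadModel(config)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
loss = model(input_ids, lm_labels=input_ids)   # scalar cross-entropy over all positions
lm_logits, presents = model(input_ids)         # logits of shape (2, 3, config.vocab_size)
print(loss.item(), lm_logits.shape)
```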

class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").

    Params:
        config: a GPT2Config class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
            indices selected in the range [0, config.vocab_size[
        `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
            which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss
            is only computed for the labels set in [0, ..., config.vocab_size]
        `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices].

    Outputs:
        if `lm_labels` and `multiple_choice_labels` are not `None`:
            Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
        else: a tuple with
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
            `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choices, seq length)
    mc_token_ids = torch.LongTensor([[2], [1]])  # (bsz, number of choices)

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2DoubleHeadsModel(config)
    lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
        self.apply(self.init_weights)

    def set_tied(self):
        """ Make sure we are sharing the embeddings
        """
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
        losses = []
        if lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
        if losses:
            return losses
        return lm_logits, mc_logits, presents

pytorch_pretrained_bert/tokenization_gpt2.py  (new file, mode 100644, +199)

# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT-2."""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import os
import regex as re
import sys
from io import open
from functools import lru_cache

from tqdm import tqdm

from .file_utils import cached_path
from .tokenization import BasicTokenizer

logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
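
`bytes_to_unicode` builds the reversible byte-to-printable-character table that lets the BPE vocabulary operate on raw UTF-8 bytes without ever producing an unknown token. Two properties worth checking (a sketch, assuming the function above is importable from this module):

```python
from pytorch_pretrained_bert.tokenization_gpt2 import bytes_to_unicode

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

assert len(byte_encoder) == 256                                       # every byte value gets a printable stand-in
assert all(byte_decoder[c] == b for b, c in byte_encoder.items())     # the mapping is reversible

# a UTF-8 string survives the round trip
s = "héllo"
mapped = ''.join(byte_encoder[b] for b in s.encode('utf-8'))
restored = bytearray(byte_decoder[c] for c in mapped).decode('utf-8')
assert restored == s
```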

def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
        - argument special_tokens and function set_special_tokens:
            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a GPT2Tokenizer from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(
                merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def __len__(self):
        return len(self.encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
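
In this first version the tokenizer exposes `encode` (regex pre-tokenization, byte mapping, then BPE, returning vocabulary ids) and `decode` (the inverse). A usage sketch (illustrative; `from_pretrained('gpt2')` downloads the vocab and merges files listed at the top of this module, so it needs network access):

```python
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode("Hello world!")   # list of int vocabulary ids
text = tokenizer.decode(ids)             # byte-level BPE is lossless, so this round-trips
assert text == "Hello world!"
```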

tests/modeling_gpt2_test.py  (new file, mode 100644, +213)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import json
import random

import torch

from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel)


class GPT2ModelTest(unittest.TestCase):
    class GPT2ModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_position_ids=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     n_special=1,
                     n_positions=33,
                     n_embd=32,
                     n_layer=5,
                     n_head=4,
                     n_choices=3,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_position_ids = use_position_ids
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_special = n_special
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.n_choices = n_choices
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)

            position_ids = None
            if self.use_position_ids:
                position_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)

            token_type_ids = None
            if self.use_token_type_ids:
                total_voc = self.vocab_size + self.n_special
                token_type_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)

            mc_labels = None
            lm_labels = None
            mc_token_ids = None
            if self.use_labels:
                mc_labels = GPT2ModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                lm_labels = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
                mc_token_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)

            config = GPT2Config(
                vocab_size_or_config_json_file=self.vocab_size,
                n_positions=self.n_positions,
                n_special=self.n_special,
                n_embd=self.n_embd,
                n_layer=self.n_layer,
                n_head=self.n_head,
                initializer_range=self.initializer_range)

            return (config, input_ids, token_type_ids, position_ids,
                    mc_labels, lm_labels, mc_token_ids)

        def create_gpt2_model(self, config, input_ids, token_type_ids, position_ids,
                              mc_labels, lm_labels, mc_token_ids):
            model = GPT2Model(config)
            model.eval()
            hidden_states, presents = model(input_ids, position_ids, token_type_ids)
            outputs = {
                "hidden_states": hidden_states,
                "presents": presents,
            }
            return outputs

        def check_gpt2_model_output(self, result):
            self.parent.assertListEqual(
                list(result["hidden_states"].size()),
                [self.batch_size, self.n_choices, self.seq_length, self.n_embd])

        def create_gpt2_lm_head(self, config, input_ids, token_type_ids, position_ids,
                                mc_labels, lm_labels, mc_token_ids):
            model = GPT2LMHeadModel(config)
            model.eval()
            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
            lm_logits, presents = model(input_ids, position_ids, token_type_ids)
            outputs = {
                "loss": loss,
                "lm_logits": lm_logits,
                "presents": presents,
            }
            return outputs

        def check_gpt2_lm_head_output(self, result):
            total_voc = self.n_special + self.vocab_size
            self.parent.assertListEqual(
                list(result["lm_logits"].size()),
                [self.batch_size, self.n_choices, self.seq_length, total_voc])

        def check_gpt2_lm_head_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_gpt2_double_heads(self, config, input_ids, token_type_ids, position_ids,
                                     mc_labels, lm_labels, mc_token_ids):
            model = GPT2DoubleHeadsModel(config)
            model.eval()
            loss = model(input_ids, mc_token_ids,
                         lm_labels=lm_labels, mc_labels=mc_labels,
                         token_type_ids=token_type_ids, position_ids=position_ids)
            lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
            outputs = {
                "loss": loss,
                "lm_logits": lm_logits,
                "mc_logits": mc_logits,
                "presents": presents,
            }
            return outputs

        def check_gpt2_double_heads_output(self, result):
            total_voc = self.n_special + self.vocab_size
            self.parent.assertListEqual(
                list(result["lm_logits"].size()),
                [self.batch_size, self.n_choices, self.seq_length, total_voc])
            self.parent.assertListEqual(
                list(result["mc_logits"].size()),
                [self.batch_size, self.n_choices])

        def check_gpt2_double_heads_loss_output(self, result):
            self.parent.assertListEqual(
                [list(l.size()) for l in result["loss"]],
                [[], []])

    def test_default(self):
        self.run_tester(GPT2ModelTest.GPT2ModelTester(self))

    def test_config_to_json_string(self):
        config = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["n_embd"], 37)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_gpt2_model(*config_and_inputs)
        tester.check_gpt2_model_output(output_result)

        output_result = tester.create_gpt2_lm_head(*config_and_inputs)
        tester.check_gpt2_lm_head_output(output_result)
        tester.check_gpt2_lm_head_loss_output(output_result)

        output_result = tester.create_gpt2_double_heads(*config_and_inputs)
        tester.check_gpt2_double_heads_output(output_result)
        tester.check_gpt2_double_heads_loss_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()

tests/tokenization_gpt2_test.py  (new file, mode 100644, +56)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
import json

from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer


class GPT2TokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
            json.dump(vocab_tokens, fp)
            vocab_file = fp.name
        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
            fp.write("\n".join(merges))
            merges_file = fp.name

        tokenizer = GPT2Tokenizer(vocab_file, merges_file)
        os.remove(vocab_file)
        os.remove(merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens),
            input_bpe_tokens)


if __name__ == '__main__':
    unittest.main()