chenpangpang / transformers · Commits

Commit ffd62382, authored Feb 17, 2019 by thomwolf

adding gpt2

Parent: 3a2f97db
Showing 7 changed files with 1246 additions and 4 deletions (+1246, −4):
  pytorch_pretrained_bert/__init__.py                             +3    −0
  pytorch_pretrained_bert/__main__.py                             +22   −4
  pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py   +72   −0
  pytorch_pretrained_bert/modeling_gpt2.py                        +681  −0
  pytorch_pretrained_bert/tokenization_gpt2.py                    +199  −0
  tests/modeling_gpt2_test.py                                     +213  −0
  tests/tokenization_gpt2_test.py                                 +56   −0
pytorch_pretrained_bert/__init__.py

```diff
@@ -13,6 +13,9 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               load_tf_weights_in_openai_gpt)
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
                                   load_tf_weights_in_transfo_xl)
+from .modeling_gpt2 import (GPT2Config, GPT2Model,
+                            GPT2LMHeadModel, GPT2DoubleHeadsModel,
+                            load_tf_weights_in_gpt2)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
```
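With this change the GPT-2 classes are re-exported from the package root alongside the existing BERT, OpenAI GPT and Transformer-XL ones. A minimal sanity check of the new import surface, assuming the package is installed at this commit:

```python
# These names are exactly the ones added to __init__.py in this commit.
from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel,
                                     load_tf_weights_in_gpt2)

config = GPT2Config()           # default configuration
print(config.to_json_string())  # to_json_string() is also used by the converter below
```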
pytorch_pretrained_bert/__main__.py

```diff
@@ -4,13 +4,15 @@ def main():
     if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
         "convert_tf_checkpoint_to_pytorch",
         "convert_openai_checkpoint",
-        "convert_transfo_xl_checkpoint"
+        "convert_transfo_xl_checkpoint",
+        "convert_gpt2_checkpoint",
     ]:
         print(
         "Should be used as one of: \n"
         ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-        ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
-        ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+        ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
+        ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
+        ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
     else:
         if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
             try:
@@ -40,7 +42,7 @@ def main():
             convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
                                                  OPENAI_GPT_CONFIG,
                                                  PYTORCH_DUMP_OUTPUT)
-        else:
+        elif sys.argv[1] == "convert_transfo_xl_checkpoint":
             try:
                 from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
             except ImportError:
@@ -61,5 +63,21 @@ def main():
             else:
                 TF_CONFIG = ""
             convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
+        else:
+            try:
+                from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
+            except ImportError:
+                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
+                      "In that case, it requires TensorFlow to be installed. Please see "
+                      "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+            TF_CHECKPOINT = sys.argv[2]
+            PYTORCH_DUMP_OUTPUT = sys.argv[3]
+            if len(sys.argv) == 5:
+                TF_CONFIG = sys.argv[4]
+            else:
+                TF_CONFIG = ""
+            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
 if __name__ == '__main__':
     main()
```
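Per the updated usage message, the new branch is reached with `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`. A sketch of driving the same dispatch from Python, with hypothetical placeholder paths (TensorFlow must be installed for the import inside the branch to succeed):

```python
# Exercise the new "convert_gpt2_checkpoint" branch without a shell.
# Both paths are hypothetical placeholders; with 4 argv entries the
# optional GPT2_CONFIG is omitted, so a default GPT2Config() is used.
import sys
from pytorch_pretrained_bert.__main__ import main

sys.argv = ["pytorch_pretrained_bert", "convert_gpt2_checkpoint",
            "/path/to/gpt2/model.ckpt",   # TF_CHECKPOINT (hypothetical)
            "/path/to/pytorch_dump"]      # PYTORCH_DUMP_OUTPUT (hypothetical)
main()
```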
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py (new file, mode 100755)

```python
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT-2 checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
from io import open

import torch

from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
                                                   GPT2Config,
                                                   GPT2Model,
                                                   load_tf_weights_in_gpt2)


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    # Construct model
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--gpt2_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    parser.add_argument("--gpt2_config_file",
                        default="",
                        type=str,
                        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
                             "This specifies the model architecture.")
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
                                       args.gpt2_config_file,
                                       args.pytorch_dump_folder_path)
```
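After a conversion run, the dump folder holds the two artifacts written above (WEIGHTS_NAME and CONFIG_NAME). A minimal sketch of loading them back, assuming a hypothetical dump path:

```python
# Reload a converted model; "/path/to/pytorch_dump" is a hypothetical
# placeholder for a folder produced by the conversion script above.
import torch
from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
                                                   GPT2Config, GPT2Model)

dump = "/path/to/pytorch_dump"
config = GPT2Config(dump + '/' + CONFIG_NAME)   # GPT2Config also accepts a JSON file path
model = GPT2Model(config)
model.load_state_dict(torch.load(dump + '/' + WEIGHTS_NAME))
model.eval()
```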
pytorch_pretrained_bert/modeling_gpt2.py (new file, mode 100644)

This diff (681 lines) is collapsed in the web view. Per the __init__.py import above, the file defines GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel and load_tf_weights_in_gpt2, plus the CONFIG_NAME and WEIGHTS_NAME constants used by the conversion script.
pytorch_pretrained_bert/tokenization_gpt2.py (new file, mode 100644)

```python
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT-2."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import os
import regex as re
import sys
from io import open
from functools import lru_cache

from tqdm import tqdm

from .file_utils import cached_path
from .tokenization import BasicTokenizer

logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    The mapping also avoids whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1)) + list(range(ord("¡"), ord("¬")+1)) + list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
        - argument special_tokens and function set_special_tokens:
            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a GPT2Tokenizer from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(
                merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def __len__(self):
        return len(self.encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
```
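A quick round trip with the tokenizer above. The 'gpt2' shortcut resolves to the S3 vocab/merges URLs in the archive maps, so this needs network access on first use (a local directory holding vocab.json and merges.txt works as well):

```python
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode("Hello world!")   # byte-level BPE -> vocabulary ids
text = tokenizer.decode(ids)             # ids -> unicode chars -> utf-8 bytes
assert text == "Hello world!"            # byte-level BPE should round-trip losslessly
```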
tests/modeling_gpt2_test.py (new file, mode 100644)

```python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import json
import random

import torch

from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
                                     GPT2LMHeadModel, GPT2DoubleHeadsModel)


class GPT2ModelTest(unittest.TestCase):
    class GPT2ModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_position_ids=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     n_special=1,
                     n_positions=33,
                     n_embd=32,
                     n_layer=5,
                     n_head=4,
                     n_choices=3,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_position_ids = use_position_ids
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_special = n_special
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.n_choices = n_choices
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)

            position_ids = None
            if self.use_position_ids:
                position_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)

            token_type_ids = None
            if self.use_token_type_ids:
                total_voc = self.vocab_size + self.n_special
                token_type_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)

            mc_labels = None
            lm_labels = None
            mc_token_ids = None
            if self.use_labels:
                mc_labels = GPT2ModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                lm_labels = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
                mc_token_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)

            config = GPT2Config(
                vocab_size_or_config_json_file=self.vocab_size,
                n_positions=self.n_positions,
                n_special=self.n_special,
                n_embd=self.n_embd,
                n_layer=self.n_layer,
                n_head=self.n_head,
                initializer_range=self.initializer_range)

            return (config, input_ids, token_type_ids, position_ids,
                    mc_labels, lm_labels, mc_token_ids)

        def create_gpt2_model(self, config, input_ids, token_type_ids, position_ids,
                              mc_labels, lm_labels, mc_token_ids):
            model = GPT2Model(config)
            model.eval()
            hidden_states, presents = model(input_ids, position_ids, token_type_ids)
            outputs = {
                "hidden_states": hidden_states,
                "presents": presents,
            }
            return outputs

        def check_gpt2_model_output(self, result):
            self.parent.assertListEqual(
                list(result["hidden_states"].size()),
                [self.batch_size, self.n_choices, self.seq_length, self.n_embd])

        def create_gpt2_lm_head(self, config, input_ids, token_type_ids, position_ids,
                                mc_labels, lm_labels, mc_token_ids):
            model = GPT2LMHeadModel(config)
            model.eval()
            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
            lm_logits, presents = model(input_ids, position_ids, token_type_ids)
            outputs = {
                "loss": loss,
                "lm_logits": lm_logits,
                "presents": presents,
            }
            return outputs

        def check_gpt2_lm_head_output(self, result):
            total_voc = self.n_special + self.vocab_size
            self.parent.assertListEqual(
                list(result["lm_logits"].size()),
                [self.batch_size, self.n_choices, self.seq_length, total_voc])

        def check_gpt2_lm_head_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_gpt2_double_heads(self, config, input_ids, token_type_ids, position_ids,
                                     mc_labels, lm_labels, mc_token_ids):
            model = GPT2DoubleHeadsModel(config)
            model.eval()
            loss = model(input_ids, mc_token_ids,
                         lm_labels=lm_labels, mc_labels=mc_labels,
                         token_type_ids=token_type_ids, position_ids=position_ids)
            lm_logits, mc_logits, presents = model(input_ids, mc_token_ids,
                                                   position_ids=position_ids,
                                                   token_type_ids=token_type_ids)
            outputs = {
                "loss": loss,
                "lm_logits": lm_logits,
                "mc_logits": mc_logits,
                "presents": presents,
            }
            return outputs

        def check_gpt2_double_heads_output(self, result):
            total_voc = self.n_special + self.vocab_size
            self.parent.assertListEqual(
                list(result["lm_logits"].size()),
                [self.batch_size, self.n_choices, self.seq_length, total_voc])
            self.parent.assertListEqual(
                list(result["mc_logits"].size()),
                [self.batch_size, self.n_choices])

        def check_gpt2_double_heads_loss_output(self, result):
            self.parent.assertListEqual(
                [list(l.size()) for l in result["loss"]],
                [[], []])

    def test_default(self):
        self.run_tester(GPT2ModelTest.GPT2ModelTester(self))

    def test_config_to_json_string(self):
        config = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["n_embd"], 37)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()
        output_result = tester.create_gpt2_model(*config_and_inputs)
        tester.check_gpt2_model_output(output_result)

        output_result = tester.create_gpt2_lm_head(*config_and_inputs)
        tester.check_gpt2_lm_head_output(output_result)
        tester.check_gpt2_lm_head_loss_output(output_result)

        output_result = tester.create_gpt2_double_heads(*config_and_inputs)
        tester.check_gpt2_double_heads_output(output_result)
        tester.check_gpt2_double_heads_loss_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()
```
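Outside the test harness, the same toy-sized forward pass looks like the sketch below. The config keywords follow the GPT2Config(...) call in prepare_config_and_inputs, and the expected hidden-state shape mirrors check_gpt2_model_output; position_ids and token_type_ids are assumed optional, as the collapsed modeling file is not shown here:

```python
# A minimal sketch mirroring GPT2ModelTester's default sizes.
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2Model

config = GPT2Config(vocab_size_or_config_json_file=99, n_positions=33,
                    n_special=1, n_embd=32, n_layer=5, n_head=4)
model = GPT2Model(config)
model.eval()

# batch_size=13, n_choices=3, seq_length=7, as in the tester defaults
input_ids = torch.randint(0, 99, (13, 3, 7), dtype=torch.long)
hidden_states, presents = model(input_ids)   # presents: cached keys/values per layer
assert list(hidden_states.size()) == [13, 3, 7, 32]
```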
tests/tokenization_gpt2_test.py (new file, mode 100644)

```python
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
import json

from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer


class GPT2TokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
            json.dump(vocab_tokens, fp)
            vocab_file = fp.name
        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
            fp.write("\n".join(merges))
            merges_file = fp.name

        tokenizer = GPT2Tokenizer(vocab_file, merges_file)
        os.remove(vocab_file)
        os.remove(merges_file)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
    unittest.main()
```