Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
98c96fb1
Commit
98c96fb1
authored
Jan 29, 2019
by
thomwolf
Browse files
splitting position and tokens embeddings in OpenAI GPT - updating tf imports - tests
parent
5456d823
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
66 additions
and
44 deletions
+66
-44
pytorch_pretrained_bert/__main__.py
pytorch_pretrained_bert/__main__.py
+2
-2
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
...h_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
+3
-6
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
+1
-1
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+8
-0
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+38
-27
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+7
-0
tests/modeling_openai_test.py
tests/modeling_openai_test.py
+7
-8
No files found.
pytorch_pretrained_bert/__main__.py
View file @
98c96fb1
...
@@ -14,7 +14,7 @@ def main():
...
@@ -14,7 +14,7 @@ def main():
else
:
else
:
if
sys
.
argv
[
1
]
==
"convert_tf_checkpoint_to_pytorch"
:
if
sys
.
argv
[
1
]
==
"convert_tf_checkpoint_to_pytorch"
:
try
:
try
:
from
.convert_tf_checkpoint_to_pytorch
import
convert_tf_checkpoint_to_pytorch
import
tensorflow
as
tf
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"In that case, it requires TensorFlow to be installed. Please see "
...
@@ -42,7 +42,7 @@ def main():
...
@@ -42,7 +42,7 @@ def main():
PYTORCH_DUMP_OUTPUT
)
PYTORCH_DUMP_OUTPUT
)
else
:
else
:
try
:
try
:
from
.convert_transfo_xl_checkpoint_to_pytorch
import
convert_transfo_xl_checkpoint_to_pytorch
import
tensorflow
as
tf
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
print
(
"pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"In that case, it requires TensorFlow to be installed. Please see "
...
...
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py
View file @
98c96fb1
...
@@ -18,13 +18,10 @@ from __future__ import absolute_import
...
@@ -18,13 +18,10 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
re
import
json
import
argparse
import
argparse
import
torch
import
torch
import
numpy
as
np
from
.modeling_openai
import
load_tf_weights_in_openai_gpt
,
OpenAIGPTConfig
,
OpenAIGPTModel
,
CONFIG_NAME
,
WEIGHTS_NAME
from
pytorch_pretrained_bert
.modeling_openai
import
load_tf_weights_in_openai_gpt
,
OpenAIGPTConfig
,
OpenAIGPTModel
,
CONFIG_NAME
,
WEIGHTS_NAME
def
convert_openai_checkpoint_to_pytorch
(
openai_checkpoint_folder_path
,
openai_config_file
,
pytorch_dump_folder_path
):
def
convert_openai_checkpoint_to_pytorch
(
openai_checkpoint_folder_path
,
openai_config_file
,
pytorch_dump_folder_path
):
# Construct model
# Construct model
...
@@ -67,5 +64,5 @@ if __name__ == "__main__":
...
@@ -67,5 +64,5 @@ if __name__ == "__main__":
"This specifies the model architecture."
)
"This specifies the model architecture."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
convert_openai_checkpoint_to_pytorch
(
args
.
openai_checkpoint_folder_path
,
convert_openai_checkpoint_to_pytorch
(
args
.
openai_checkpoint_folder_path
,
args
.
pytorch_dump_folder_path
,
args
.
openai_config_file
,
args
.
openai_config_file
)
args
.
pytorch_dump_folder_path
)
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py
View file @
98c96fb1
...
@@ -25,7 +25,7 @@ import tensorflow as tf
...
@@ -25,7 +25,7 @@ import tensorflow as tf
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
.modeling
import
BertConfig
,
BertForPreTraining
,
load_tf_weights_in_bert
from
pytorch_pretrained_bert
.modeling
import
BertConfig
,
BertForPreTraining
,
load_tf_weights_in_bert
def
convert_tf_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
def
convert_tf_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
# Initialise PyTorch model
# Initialise PyTorch model
...
...
pytorch_pretrained_bert/modeling.py
View file @
98c96fb1
...
@@ -52,6 +52,14 @@ TF_WEIGHTS_NAME = 'model.ckpt'
...
@@ -52,6 +52,14 @@ TF_WEIGHTS_NAME = 'model.ckpt'
def
load_tf_weights_in_bert
(
model
,
tf_checkpoint_path
):
def
load_tf_weights_in_bert
(
model
,
tf_checkpoint_path
):
""" Load tf checkpoints in a pytorch model
""" Load tf checkpoints in a pytorch model
"""
"""
try
:
import
re
import
numpy
as
np
import
tensorflow
as
tf
except
ModuleNotFoundError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
# Load weights from TF model
# Load weights from TF model
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
98c96fb1
...
@@ -15,23 +15,23 @@
...
@@ -15,23 +15,23 @@
# limitations under the License.
# limitations under the License.
"""PyTorch OpenAI GPT model."""
"""PyTorch OpenAI GPT model."""
import
o
s
import
collection
s
import
copy
import
copy
import
json
import
json
import
math
import
logging
import
logging
import
math
import
os
import
shutil
import
tarfile
import
tarfile
import
tempfile
import
tempfile
import
shutil
import
collections
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
CrossEntropyLoss
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
.modeling
import
BertLayerNorm
as
LayerNorm
from
.file_utils
import
cached_path
from
.file_utils
import
cached_path
from
.modeling
import
BertLayerNorm
as
LayerNorm
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -42,6 +42,8 @@ WEIGHTS_NAME = "pytorch_model.bin"
...
@@ -42,6 +42,8 @@ WEIGHTS_NAME = "pytorch_model.bin"
def
load_tf_weights_in_openai_gpt
(
model
,
openai_checkpoint_folder_path
):
def
load_tf_weights_in_openai_gpt
(
model
,
openai_checkpoint_folder_path
):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
"""
import
re
import
numpy
as
np
print
(
"Loading weights..."
)
print
(
"Loading weights..."
)
names
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/parameters_names.json'
,
"r"
,
encoding
=
'utf-8'
))
names
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/parameters_names.json'
,
"r"
,
encoding
=
'utf-8'
))
shapes
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/params_shapes.json'
,
"r"
,
encoding
=
'utf-8'
))
shapes
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/params_shapes.json'
,
"r"
,
encoding
=
'utf-8'
))
...
@@ -50,18 +52,24 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
...
@@ -50,18 +52,24 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
init_params
=
np
.
split
(
np
.
concatenate
(
init_params
,
0
),
offsets
)[:
-
1
]
init_params
=
np
.
split
(
np
.
concatenate
(
init_params
,
0
),
offsets
)[:
-
1
]
init_params
=
[
param
.
reshape
(
shape
)
for
param
,
shape
in
zip
(
init_params
,
shapes
)]
init_params
=
[
param
.
reshape
(
shape
)
for
param
,
shape
in
zip
(
init_params
,
shapes
)]
init_params
[
0
]
=
np
.
concatenate
([
init_params
[
1
],
init_params
[
0
]],
0
)
# This was used when we had a single embedding matrix for positions and tokens
del
init_params
[
1
]
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
# del init_params[1]
init_params
=
[
arr
.
squeeze
()
for
arr
in
init_params
]
init_params
=
[
arr
.
squeeze
()
for
arr
in
init_params
]
try
:
try
:
assert
model
.
embed
.
weight
.
shape
==
init_params
[
0
].
shape
assert
model
.
tokens_embed
.
weight
.
shape
==
init_params
[
1
].
shape
assert
model
.
positions_embed
.
weight
.
shape
==
init_params
[
0
].
shape
except
AssertionError
as
e
:
except
AssertionError
as
e
:
e
.
args
+=
(
model
.
embed
.
weight
.
shape
,
init_params
[
0
].
shape
)
e
.
args
+=
(
model
.
tokens_embed
.
weight
.
shape
,
init_params
[
1
].
shape
)
e
.
args
+=
(
model
.
positions_embed
.
weight
.
shape
,
init_params
[
0
].
shape
)
raise
raise
model
.
embed
.
weight
.
data
=
torch
.
from_numpy
(
init_params
[
0
])
model
.
tokens_embed
.
weight
.
data
=
torch
.
from_numpy
(
init_params
[
1
])
model
.
positions_embed
.
weight
.
data
=
torch
.
from_numpy
(
init_params
[
0
])
names
.
pop
(
0
)
names
.
pop
(
0
)
# Pop position and token embedding arrays
init_params
.
pop
(
0
)
init_params
.
pop
(
0
)
init_params
.
pop
(
0
)
for
name
,
array
in
zip
(
names
,
init_params
):
# names[1:n_transfer], init_params[1:n_transfer]):
for
name
,
array
in
zip
(
names
,
init_params
):
# names[1:n_transfer], init_params[1:n_transfer]):
...
@@ -584,8 +592,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -584,8 +592,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTModel
,
self
).
__init__
(
config
)
super
(
OpenAIGPTModel
,
self
).
__init__
(
config
)
total_embeddings_size
=
config
.
vocab_size
+
config
.
n_special
+
config
.
n_positions
num_tokens
=
config
.
vocab_size
+
config
.
n_special
self
.
embed
=
nn
.
Embedding
(
total_embeddings_size
,
config
.
n_embd
)
self
.
tokens_embed
=
nn
.
Embedding
(
num_tokens
,
config
.
n_embd
)
self
.
positions_embed
=
nn
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
self
.
h
=
nn
.
ModuleList
([
copy
.
deepcopy
(
block
)
for
_
in
range
(
config
.
n_layer
)])
self
.
h
=
nn
.
ModuleList
([
copy
.
deepcopy
(
block
)
for
_
in
range
(
config
.
n_layer
)])
...
@@ -598,30 +607,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -598,30 +607,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Update config
# Update config
self
.
config
.
n_special
=
num_special_tokens
self
.
config
.
n_special
=
num_special_tokens
# # Build new embeddings and initialize
# # Build new embeddings and initialize
old_embed
=
self
.
embed
old_embed
=
self
.
tokens_
embed
self
.
embed
=
nn
.
Embedding
(
self
.
config
.
total_num_embeddings
,
self
.
config
.
n_embd
)
self
.
tokens_
embed
=
nn
.
Embedding
(
self
.
config
.
total_num_embeddings
,
self
.
config
.
n_embd
)
# Initialize all new embeddings (in particular the special tokens)
# Initialize all new embeddings (in particular the special tokens)
self
.
init_weights
(
self
.
embed
)
self
.
init_weights
(
self
.
tokens_
embed
)
# Copy word and positional embeddings from the previous weights
# Copy word and positional embeddings from the previous weights
self
.
embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
self
.
tokens_
embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
self
.
embed
.
weight
.
data
[
-
self
.
config
.
n_positions
:,
:]
=
old_embed
.
weight
.
data
[
-
self
.
config
.
n_positions
:,
:]
self
.
tokens_
embed
.
weight
.
data
[
-
self
.
config
.
n_positions
:,
:]
=
old_embed
.
weight
.
data
[
-
self
.
config
.
n_positions
:,
:]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
):
if
position_ids
is
None
:
if
position_ids
is
None
:
start
=
self
.
config
.
vocab_size
+
self
.
config
.
n_special
# This was used when we had a single embedding matrix for position and token embeddings
end
=
start
+
input_ids
.
size
(
-
1
)
# start = self.config.vocab_size + self.config.n_special
position_ids
=
torch
.
arange
(
start
,
end
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
# end = start + input_ids.size(-1)
# position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
position_ids
=
torch
.
arange
(
input_ids
.
size
(
-
1
),
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
input_shape
=
input_ids
.
size
()
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
inputs_embeds
=
self
.
embed
(
input_ids
)
inputs_embeds
=
self
.
tokens_
embed
(
input_ids
)
position_embeds
=
self
.
embed
(
position_ids
)
position_embeds
=
self
.
positions_
embed
(
position_ids
)
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
token_type_embeds
=
self
.
embed
(
token_type_ids
)
token_type_embeds
=
self
.
tokens_
embed
(
token_type_ids
)
else
:
else
:
token_type_embeds
=
0
token_type_embeds
=
0
# Add the position information to the input embeddings
# Add the position information to the input embeddings
...
@@ -694,13 +705,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -694,13 +705,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTLMHeadModel
,
self
).
__init__
(
config
)
super
(
OpenAIGPTLMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
self
.
lm_head
=
OpenAIGPTLMHead
(
self
.
transformer
.
embed
.
weight
,
config
)
self
.
lm_head
=
OpenAIGPTLMHead
(
self
.
transformer
.
tokens_
embed
.
weight
,
config
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
" Update input and output embeddings with new embedding matrice "
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_
embed
.
weight
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
):
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
...
@@ -780,14 +791,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -780,14 +791,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTDoubleHeadsModel
,
self
).
__init__
(
config
)
super
(
OpenAIGPTDoubleHeadsModel
,
self
).
__init__
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
self
.
transformer
=
OpenAIGPTModel
(
config
)
self
.
lm_head
=
OpenAIGPTLMHead
(
self
.
transformer
.
embed
.
weight
,
config
)
self
.
lm_head
=
OpenAIGPTLMHead
(
self
.
transformer
.
tokens_
embed
.
weight
,
config
)
self
.
multiple_choice_head
=
OpenAIGPTMultipleChoiceHead
(
config
)
self
.
multiple_choice_head
=
OpenAIGPTMultipleChoiceHead
(
config
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input and output embeddings with new embedding matrice "
" Update input and output embeddings with new embedding matrice "
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
embed
.
weight
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
tokens_
embed
.
weight
)
def
forward
(
self
,
input_ids
,
mc_token_mask
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
):
def
forward
(
self
,
input_ids
,
mc_token_mask
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
):
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
98c96fb1
...
@@ -121,6 +121,13 @@ def build_tf_to_pytorch_map(model, config):
...
@@ -121,6 +121,13 @@ def build_tf_to_pytorch_map(model, config):
def
load_tf_weights_in_transfo_xl
(
model
,
config
,
tf_path
):
def
load_tf_weights_in_transfo_xl
(
model
,
config
,
tf_path
):
""" Load tf checkpoints in a pytorch model
""" Load tf checkpoints in a pytorch model
"""
"""
try
:
import
numpy
as
np
import
tensorflow
as
tf
except
ModuleNotFoundError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
# Build TF to PyTorch weights loading map
# Build TF to PyTorch weights loading map
tf_to_pt_map
=
build_tf_to_pytorch_map
(
model
,
config
)
tf_to_pt_map
=
build_tf_to_pytorch_map
(
model
,
config
)
...
...
tests/modeling_openai_test.py
View file @
98c96fb1
...
@@ -39,7 +39,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -39,7 +39,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
use_labels
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
vocab_size
=
99
,
n_special
=
1
,
n_special
=
1
,
n_
ctx
=
33
,
n_
positions
=
33
,
n_embd
=
32
,
n_embd
=
32
,
n_layer
=
5
,
n_layer
=
5
,
n_head
=
4
,
n_head
=
4
,
...
@@ -61,7 +61,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -61,7 +61,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
self
.
use_labels
=
use_labels
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
vocab_size
=
vocab_size
self
.
n_special
=
n_special
self
.
n_special
=
n_special
self
.
n_
ctx
=
n_ctx
self
.
n_
positions
=
n_positions
self
.
n_embd
=
n_embd
self
.
n_embd
=
n_embd
self
.
n_layer
=
n_layer
self
.
n_layer
=
n_layer
self
.
n_head
=
n_head
self
.
n_head
=
n_head
...
@@ -80,12 +80,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -80,12 +80,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
position_ids
=
None
position_ids
=
None
if
self
.
use_position_ids
:
if
self
.
use_position_ids
:
position_ids
=
OpenAIGPTModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
self
.
n_ctx
)
position_ids
=
OpenAIGPTModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
self
.
n_positions
)
position_ids
=
position_ids
+
self
.
n_special
+
self
.
vocab_size
token_type_ids
=
None
token_type_ids
=
None
if
self
.
use_token_type_ids
:
if
self
.
use_token_type_ids
:
total_voc
=
self
.
n_ctx
+
self
.
n_special
+
self
.
vocab_size
total_voc
=
self
.
vocab_size
+
self
.
n_special
token_type_ids
=
OpenAIGPTModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
total_voc
)
token_type_ids
=
OpenAIGPTModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
total_voc
)
mc_labels
=
None
mc_labels
=
None
...
@@ -98,7 +97,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -98,7 +97,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
config
=
OpenAIGPTConfig
(
config
=
OpenAIGPTConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
n_
ctx
=
self
.
n_ctx
,
n_
positions
=
self
.
n_positions
,
n_special
=
self
.
n_special
,
n_special
=
self
.
n_special
,
n_embd
=
self
.
n_embd
,
n_embd
=
self
.
n_embd
,
n_layer
=
self
.
n_layer
,
n_layer
=
self
.
n_layer
,
...
@@ -139,7 +138,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -139,7 +138,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
return
outputs
return
outputs
def
check_openai_lm_head_output
(
self
,
result
):
def
check_openai_lm_head_output
(
self
,
result
):
total_voc
=
self
.
n_ctx
+
self
.
n_special
+
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
...
@@ -164,7 +163,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
...
@@ -164,7 +163,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
return
outputs
return
outputs
def
check_openai_double_heads_output
(
self
,
result
):
def
check_openai_double_heads_output
(
self
,
result
):
total_voc
=
self
.
n_ctx
+
self
.
n_special
+
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment