chenpangpang / transformers · Commits

Commit 45709d75, authored Jun 21, 2019 by thomwolf
Parent: b407972e

    model running with simple inputs

Showing 7 changed files with 404 additions and 108 deletions (+404 −108)
pytorch_pretrained_bert/__init__.py                               +3    −0
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py    +2    −2
pytorch_pretrained_bert/modeling.py                               +1    −1
pytorch_pretrained_bert/modeling_gpt2.py                          +1    −1
pytorch_pretrained_bert/modeling_openai.py                        +1    −1
pytorch_pretrained_bert/modeling_xlnet.py                         +149  −103
tests/modeling_xlnet_test.py                                      +247  −0
pytorch_pretrained_bert/__init__.py

@@ -17,6 +17,9 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
 from .modeling_gpt2 import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel,
                             GPT2MultipleChoiceHead, load_tf_weights_in_gpt2)
+from .modeling_xlnet import (XLNetBaseConfig, XLNetConfig, XLNetRunConfig,
+                             XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
+                             load_tf_weights_in_xlnet)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
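With these exports in place, the XLNet classes can now be imported from the package root rather than from the submodule, e.g.:

    from pytorch_pretrained_bert import XLNetConfig, XLNetModel, XLNetLMHeadModel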
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py

@@ -21,13 +21,13 @@ from __future__ import print_function
 import argparse
 import torch

-from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetModel, load_tf_weights_in_xlnet
+from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel, load_tf_weights_in_xlnet

 def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
     # Initialise PyTorch model
     config = XLNetConfig.from_json_file(bert_config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = XLNetModel(config)
+    model = XLNetLMHeadModel(config)

     # Load weights from tf checkpoint
     load_tf_weights_in_xlnet(model, tf_checkpoint_path)
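The conversion script now builds the full LM-head model before loading the TensorFlow weights, so the language-modeling head is converted along with the transformer body. A hypothetical invocation (the checkpoint, config, and dump paths below are placeholders, not files shipped with the repository):

    from pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch import (
        convert_xlnet_checkpoint_to_pytorch)

    convert_xlnet_checkpoint_to_pytorch(
        tf_checkpoint_path="/path/to/xlnet_model.ckpt",   # placeholder path
        bert_config_file="/path/to/xlnet_config.json",    # placeholder path
        pytorch_dump_path="/path/to/pytorch_dump")        # placeholder path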
pytorch_pretrained_bert/modeling.py

@@ -867,7 +867,7 @@ class BertModel(BertPreTrainedModel):
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
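The one-line fix here (repeated below in modeling_gpt2.py and modeling_openai.py with config.n_layer) swaps expand_as for expand: Tensor.expand_as takes another tensor as its argument, so passing it a list of sizes raises a TypeError, whereas Tensor.expand accepts the target sizes directly. A minimal sketch of the intended broadcast, with illustrative sizes:

    import torch

    n_layer, n_head = 5, 4                       # illustrative sizes only
    head_mask = torch.ones(n_head)               # 1-D mask: one value per attention head
    head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    print(head_mask.shape)                       # torch.Size([1, 1, 4, 1, 1])
    head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)
    print(head_mask.shape)                       # torch.Size([5, 1, 4, 1, 1]), one slice per layer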
pytorch_pretrained_bert/modeling_gpt2.py

@@ -722,7 +722,7 @@ class GPT2Model(GPT2PreTrainedModel):
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.n_layer, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
pytorch_pretrained_bert/modeling_openai.py

@@ -718,7 +718,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.n_layer, -1, -1, -1, -1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
             elif head_mask.dim() == 2:
                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
pytorch_pretrained_bert/modeling_xlnet.py

(diff collapsed in the page view: +149 −103)
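Since this diff is collapsed, the shape of the new XLNet API can only be inferred from the test file that follows. The sketch below mirrors calls made in those tests and nothing more (it is not the implementation; all configuration values are taken from the tests):

    import torch
    from pytorch_pretrained_bert import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel

    # Values mirror tests/modeling_xlnet_test.py below.
    config = XLNetConfig(vocab_size_or_config_json_file=99, d_model=32, n_head=4,
                         d_inner=128, n_layer=5, untie_r=True,
                         max_position_embeddings=10)
    config.update(XLNetRunConfig(mem_len=30, clamp_len=15, same_length=False,
                                 reuse_len=15, bi_data=False))

    model = XLNetLMHeadModel(config)
    model.eval()

    # Inputs are laid out [seq_length, batch_size] in these tests.
    input_ids = torch.randint(0, 99, (7, 13), dtype=torch.long)
    seg_ids = torch.randint(0, 2, (7, 13), dtype=torch.long)

    hidden_states_1, mems_1 = model(input_ids, seg_id=seg_ids)               # fresh segment
    hidden_states_2, mems_2 = model(input_ids, seg_id=seg_ids, mems=mems_1)  # memory reuse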
tests/modeling_xlnet_test.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_pretrained_bert import (XLNetConfig, XLNetRunConfig,
                                     XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP


class XLNetModelTest(unittest.TestCase):
    class XLNetModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=30,
                     clamp_len=15,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     d_model=32,
                     n_head=4,
                     d_inner=128,
                     n_layer=5,
                     max_position_embeddings=10,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     seed=1,
                     type_vocab_size=2):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.d_model = d_model
            self.n_head = n_head
            self.d_inner = d_inner
            self.n_layer = n_layer
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.seed = seed
            self.type_vocab_size = type_vocab_size

        def prepare_config_and_inputs(self):
            # Inputs are laid out [seq_length, batch_size].
            input_ids_1 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)

            lm_labels = None
            if self.use_labels:
                lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.d_model,
                n_head=self.n_head,
                d_inner=self.d_inner,
                n_layer=self.n_layer,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings)
            run_config = XLNetRunConfig(
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data)
            config.update(run_config)

            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)

        def set_seed(self):
            random.seed(self.seed)
            torch.manual_seed(self.seed)

        def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            # Run two consecutive segments; the second reuses the memories
            # returned by the first.
            hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
            hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)

            outputs = {
                "hidden_states_1": hidden_states_1,
                "mems_1": mems_1,
                "hidden_states_2": hidden_states_2,
                "mems_2": mems_2,
            }
            return outputs

        def check_transfo_xl_model_output(self, result):
            self.parent.assertListEqual(
                list(result["hidden_states_1"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(result["hidden_states_2"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)

        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            # "a" passes supply a target and return a loss; "b" passes return logits.
            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
            lm_logits_1, mems_1b = model(input_ids_1)

            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)

            outputs = {
                "loss_1": loss_1,
                "mems_1a": mems_1a,
                "lm_logits_1": lm_logits_1,
                "mems_1b": mems_1b,
                "loss_2": loss_2,
                "mems_2a": mems_2a,
                "lm_logits_2": lm_logits_2,
                "mems_2b": mems_2b,
            }
            return outputs

        def check_transfo_xl_lm_head_output(self, result):
            self.parent.assertListEqual(
                list(result["loss_1"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_1"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            # The memories must match whether or not a target was supplied.
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))

            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_2"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))

    def test_default(self):
        self.run_tester(XLNetModelTest.XLNetModelTester(self))

    def test_config_to_json_string(self):
        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_model"], 37)

    def test_config_to_json_file(self):
        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = XLNetConfig.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()

        tester.set_seed()
        output_result = tester.create_transfo_xl_model(*config_and_inputs)
        tester.check_transfo_xl_model_output(output_result)

        tester.set_seed()
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        tester.check_transfo_xl_lm_head_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int tensor of the given shape with values below the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()
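Because the module ends with the unittest.main() guard, the new tests can be run directly with "python tests/modeling_xlnet_test.py" from the repository root (assuming the package is importable), or through pytest; the download-heavy test_model_from_pretrained is marked @pytest.mark.slow and can be deselected with pytest -m "not slow".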