chenpangpang / transformers / Commits

Commit 32aabe8c
Authored Sep 10, 2019 by thomwolf
WIP XLNet
Parent: f851fb55
Showing 7 changed files with 1540 additions and 68 deletions.
pytorch_transformers/__init__.py                        +2     -2
pytorch_transformers/configuration_utils.py             +1     -0
pytorch_transformers/modeling_tf_gpt2.py                +10    -64
pytorch_transformers/modeling_tf_utils.py               +63    -0
pytorch_transformers/modeling_tf_xlnet.py               +1121  -0
pytorch_transformers/tests/modeling_tf_common_test.py   +2     -2
pytorch_transformers/tests/modeling_tf_xlnet_test.py    +341   -0
pytorch_transformers/__init__.py

@@ -95,7 +95,7 @@ except (ImportError, AssertionError):
 if _tf_available:
     logger.info("TensorFlow version {} available.".format(tf.__version__))
-    from .modeling_tf_utils import TFPreTrainedModel
+    from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
     from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
                                    TFAutoModelWithLMHead)
@@ -107,7 +107,7 @@ if _tf_available:
                                    load_bert_pt_weights_in_tf2,
                                    TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
-    from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings,
+    from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
                                    TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
                                    load_gpt2_pt_weights_in_tf2,
                                    TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
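Aside: with this hunk, TFSharedEmbeddings and TFSequenceSummary become importable from the package root whenever TensorFlow is available. A minimal sketch of that import path (hypothetical usage, not part of the commit; assumes this branch and a TF 2.0 install):

    from pytorch_transformers import is_tf_available

    if is_tf_available():
        # Both names are re-exported from modeling_tf_utils by the new import line above.
        from pytorch_transformers import TFSharedEmbeddings, TFSequenceSummary
        wte = TFSharedEmbeddings(vocab_size=100, hidden_size=32, name="wte")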
pytorch_transformers/configuration_utils.py

@@ -54,6 +54,7 @@ class PretrainedConfig(object):
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
         self.torchscript = kwargs.pop('torchscript', False)
+        self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
         self.pruned_heads = kwargs.pop('pruned_heads', {})

     def save_pretrained(self, save_directory):
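The new attribute follows the same kwargs.pop pattern as its neighbours, so every concrete config subclass inherits a use_bfloat16 flag that defaults to False. A standalone sketch of the pattern (MyConfig is a made-up class, not from the library):

    class MyConfig(object):
        def __init__(self, **kwargs):
            # Each known option is popped with a default, the same way
            # PretrainedConfig.__init__ handles its generic options.
            self.torchscript = kwargs.pop('torchscript', False)
            self.use_bfloat16 = kwargs.pop('use_bfloat16', False)  # new flag, off by default
            self.pruned_heads = kwargs.pop('pruned_heads', {})

    config = MyConfig(use_bfloat16=True)
    assert config.use_bfloat16 is True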
pytorch_transformers/modeling_tf_gpt2.py

@@ -28,7 +28,8 @@ from io import open
 import numpy as np
 import tensorflow as tf

-from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
+from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
+                                TFSequenceSummary, shape_list)
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings
@@ -65,6 +66,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
     weight_value_tuples = []
+    all_pytorch_weights = set(list(state_dict.keys()))
     for symbolic_weight in symbolic_weights:
         name = symbolic_weight.name
         name = name.replace(':0', '')
@@ -100,13 +102,13 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
         weight_value_tuples.append((symbolic_weight, array))

-        state_dict.pop(name)
+        all_pytorch_weights.discard(name)

     K.batch_set_value(weight_value_tuples)

     tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

-    assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys()))
+    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))

     return tf_model
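These two hunks change the PyTorch-to-TF2 loader from mutating state_dict and then hard-asserting it is empty, to tracking leftover names in a set and logging them, so PyTorch-only buffers no longer abort conversion. A standalone sketch of that bookkeeping (plain Python, toy data, no TF or PyTorch required):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Toy stand-ins: names present in the PyTorch state dict vs. names matched on the TF side.
    state_dict = {"transformer.wte.weight": None, "transformer.h.0.attn.bias": None}
    matched_tf_names = ["transformer.wte.weight"]

    all_pytorch_weights = set(state_dict.keys())
    for name in matched_tf_names:
        all_pytorch_weights.discard(name)   # discard() never raises, unlike pop()

    # Whatever is left was never matched to a TF variable; report it instead of failing.
    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))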
@@ -267,65 +269,6 @@ class TFBlock(tf.keras.layers.Layer):
         outputs = [x] + output_attn[1:]
         return outputs  # x, present, (attentions)

-
-class TFGPT2Embeddings(tf.keras.layers.Layer):
-    """Construct the embeddings from word, position and token_type embeddings.
-    """
-    def __init__(self, config, **kwargs):
-        super(TFGPT2Embeddings, self).__init__(**kwargs)
-        self.vocab_size = config.vocab_size
-        self.hidden_size = config.hidden_size
-
-    def build(self, input_shape):
-        """Build shared word embedding layer
-        Shared weights logic adapted from
-            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
-        """
-        self.weight = self.add_weight(
-            "weight",
-            shape=[self.vocab_size, self.hidden_size],
-            initializer=tf.random_normal_initializer(mean=0., stddev=self.hidden_size**-0.5))
-        super(TFGPT2Embeddings, self).build(input_shape)
-
-    def call(self, inputs, mode="embedding"):
-        """Get token embeddings of inputs.
-        Args:
-            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
-            mode: string, a valid value is one of "embedding" and "linear".
-        Returns:
-            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
-                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
-                linear tensor, float32 with shape [batch_size, length, vocab_size].
-        Raises:
-            ValueError: if mode is not valid.
-        Shared weights logic adapted from
-            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
-        """
-        if mode == "embedding":
-            return self._embedding(inputs)
-        elif mode == "linear":
-            return self._linear(inputs)
-        else:
-            raise ValueError("mode {} is not valid.".format(mode))
-
-    def _embedding(self, input_ids):
-        """Applies embedding based on inputs tensor."""
-        return tf.gather(self.weight, input_ids)
-
-    def _linear(self, inputs):
-        """Computes logits by running inputs through a linear layer.
-        Args:
-            inputs: A float32 tensor with shape [..., hidden_size]
-        Returns:
-            float32 tensor with shape [..., vocab_size].
-        """
-        first_dims = shape_list(inputs)[:-1]
-        x = tf.reshape(inputs, [-1, self.hidden_size])
-        logits = tf.matmul(x, self.weight, transpose_b=True)
-        return tf.reshape(logits, first_dims + [self.vocab_size])
-
-
 class TFGPT2MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
@@ -336,10 +279,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         self.vocab_size = config.vocab_size
         self.n_embd = config.n_embd

-        self.wte = TFGPT2Embeddings(config, name='wte')
+        self.wte = TFSharedEmbeddings(config.vocab_size, config.hidden_size, name='wte')
         self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
         self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

     def _resize_token_embeddings(self, new_num_tokens):
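Swapping the GPT-2-specific TFGPT2Embeddings for the generic TFSharedEmbeddings keeps the behaviour of tying input and output embeddings: one [vocab_size, hidden_size] matrix is used as a lookup table on the way in and, transposed, as the LM-head projection on the way out. A sketch of that tying idea in raw TensorFlow ops (toy shapes, not the library class itself):

    import tensorflow as tf

    vocab_size, hidden_size = 100, 32
    weight = tf.random.normal([vocab_size, hidden_size], stddev=hidden_size ** -0.5)

    input_ids = tf.constant([[1, 2, 3]])                  # [batch, length]
    hidden = tf.gather(weight, input_ids)                 # "embedding" mode: [1, 3, 32]

    # "linear" mode: flatten, multiply by the transposed table, restore leading dims.
    flat = tf.reshape(hidden, [-1, hidden_size])
    logits = tf.reshape(tf.matmul(flat, weight, transpose_b=True), [1, 3, vocab_size])
    print(logits.shape)                                   # (1, 3, 100)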
pytorch_transformers/modeling_tf_utils.py

@@ -288,6 +288,69 @@ class TFConv1D(tf.keras.layers.Layer):
         return x

+
+class TFSharedEmbeddings(tf.keras.layers.Layer):
+    """Construct shared token embeddings.
+    """
+    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
+        super(TFSharedEmbeddings, self).__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.initializer_range = initializer_range
+
+    def build(self, input_shape):
+        """Build shared word embedding layer
+        Shared weights logic adapted from
+            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        initializer_range = self.hidden_size**-0.5 if self.initializer_range is None else self.initializer_range
+        self.weight = self.add_weight(
+            "weight",
+            shape=[self.vocab_size, self.hidden_size],
+            initializer=tf.random_normal_initializer(mean=0., stddev=initializer_range))
+        super(TFSharedEmbeddings, self).build(input_shape)
+
+    def call(self, inputs, mode="embedding"):
+        """Get token embeddings of inputs.
+        Args:
+            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
+            mode: string, a valid value is one of "embedding" and "linear".
+        Returns:
+            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+                linear tensor, float32 with shape [batch_size, length, vocab_size].
+        Raises:
+            ValueError: if mode is not valid.
+        Shared weights logic adapted from
+            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        if mode == "embedding":
+            return self._embedding(inputs)
+        elif mode == "linear":
+            return self._linear(inputs)
+        else:
+            raise ValueError("mode {} is not valid.".format(mode))
+
+    def _embedding(self, input_ids):
+        """Applies embedding based on inputs tensor."""
+        return tf.gather(self.weight, input_ids)
+
+    def _linear(self, inputs):
+        """Computes logits by running inputs through a linear layer.
+        Args:
+            inputs: A float32 tensor with shape [..., hidden_size]
+        Returns:
+            float32 tensor with shape [..., vocab_size].
+        """
+        first_dims = shape_list(inputs)[:-1]
+        x = tf.reshape(inputs, [-1, self.hidden_size])
+        logits = tf.matmul(x, self.weight, transpose_b=True)
+        return tf.reshape(logits, first_dims + [self.vocab_size])
+
+
 class TFSequenceSummary(tf.keras.layers.Layer):
     r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
         Args of the config class:
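TFSharedEmbeddings is the generic version of the layer removed from modeling_tf_gpt2.py: it takes explicit vocab_size and hidden_size arguments and defaults the initializer standard deviation to hidden_size ** -0.5 when no initializer_range is given. A usage sketch (assumes this branch is installed and TF 2.0 eager execution):

    import tensorflow as tf
    from pytorch_transformers.modeling_tf_utils import TFSharedEmbeddings

    emb = TFSharedEmbeddings(vocab_size=100, hidden_size=32, name="shared")
    ids = tf.constant([[4, 8, 15]])          # [batch, length] int ids
    hidden = emb(ids)                        # default mode="embedding" -> [1, 3, 32]
    logits = emb(hidden, mode="linear")      # tied output projection   -> [1, 3, 100]
    print(hidden.shape, logits.shape)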
pytorch_transformers/modeling_tf_xlnet.py (new file, mode 100644)

This diff is collapsed on the commit page and not reproduced here (+1121 lines).
pytorch_transformers/tests/modeling_tf_common_test.py

@@ -262,7 +262,7 @@ class TFCommonTestCases:
         # self.assertEqual(len(params_tied_2), len(params_tied))

-def ids_tensor(shape, vocab_size, rng=None, name=None):
+def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
     """Creates a random int32 tensor of the shape within the vocab size."""
     if rng is None:
         rng = random.Random()
@@ -275,7 +275,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
     for _ in range(total_dims):
         values.append(rng.randint(0, vocab_size - 1))

-    return tf.constant(values, shape=shape)
+    return tf.constant(values, shape=shape, dtype=dtype)


 class TFModelUtilsTest(unittest.TestCase):
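The added dtype parameter lets the same helper build float tensors (for example the 0/1 input_mask and is_impossible_labels in the new XLNet test) as well as int32 id tensors. A standalone sketch with the same logic as the updated helper (the unused name argument omitted):

    import random
    import tensorflow as tf

    def ids_tensor(shape, vocab_size, rng=None, dtype=tf.int32):
        # Random tensor of the given shape with values drawn from [0, vocab_size).
        if rng is None:
            rng = random.Random()
        total_dims = 1
        for dim in shape:
            total_dims *= dim
        values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
        return tf.constant(values, shape=shape, dtype=dtype)

    input_mask = ids_tensor([2, 5], 2, dtype=tf.float32)   # float 0/1 mask
    assert input_mask.dtype == tf.float32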
pytorch_transformers/tests/modeling_tf_xlnet_test.py (new file, mode 100644)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

from pytorch_transformers import XLNetConfig, is_tf_available

if is_tf_available():
    import tensorflow as tf
    from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
    # XLNetLMHeadModel,
    # XLNetForSequenceClassification, XLNetForQuestionAnswering)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")

from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester


class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):

    all_model_classes = (TFXLNetModel, ) if is_tf_available() else ()
    # all_model_classes = (TFXLNetModel, TFXLNetLMHeadModel,
    #                      TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else ()
    test_pruning = False

    class TFXLNetModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=10,
                     clamp_len=-1,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     hidden_size=32,
                     num_attention_heads=4,
                     d_inner=128,
                     num_hidden_layers=5,
                     max_position_embeddings=10,
                     type_sequence_label_size=2,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     initializer_range=0.05,
                     seed=1,
                     type_vocab_size=2):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            # self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.d_inner = d_inner
            self.num_hidden_layers = num_hidden_layers
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.initializer_range = initializer_range
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)

            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
            perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32)
            perm_mask_last = tf.ones((self.batch_size, self.seq_length + 1, 1), dtype=tf.float32)
            perm_mask = tf.concat([perm_mask, perm_mask_last], axis=-1)
            # perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            target_mapping = tf.zeros((self.batch_size, 1, self.seq_length), dtype=tf.float32)
            target_mapping_last = tf.ones((self.batch_size, 1, 1), dtype=tf.float32)
            target_mapping = tf.concat([target_mapping, target_mapping_last], axis=-1)
            # target_mapping[:, 0, -1] = 1.0  # predict last token
            sequence_labels = None
            lm_labels = None
            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.hidden_size,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size)

            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)

        def set_seed(self):
            random.seed(self.seed)
            tf.random.set_seed(self.seed)
        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                                              input_mask, target_mapping, segment_ids, lm_labels,
                                              sequence_labels, is_impossible_labels):
            model = TFXLNetModel(config)

            # Run the model once with a dict of named inputs and once with a positional list.
            inputs = {'input_ids': input_ids_1,
                      'input_mask': input_mask,
                      'token_type_ids': segment_ids}
            _, _ = model(inputs)

            inputs = [input_ids_1, input_mask]
            outputs, mems_1 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "outputs": outputs.numpy(),
            }

            self.parent.assertListEqual(
                list(result["outputs"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                                           input_mask, target_mapping, segment_ids, lm_labels,
                                           sequence_labels, is_impossible_labels):
            pass
            # model = XLNetLMHeadModel(config)
            # model.eval()
            # loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
            # loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
            # logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
            # result = {
            #     "loss_1": loss_1,
            #     "mems_1": mems_1,
            #     "all_logits_1": all_logits_1,
            #     "loss_2": loss_2,
            #     "mems_2": mems_2,
            #     "all_logits_2": all_logits_2,
            # }
            # self.parent.assertListEqual(
            #     list(result["loss_1"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["all_logits_1"].size()),
            #     [self.batch_size, self.seq_length, self.vocab_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems_1"]),
            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            # self.parent.assertListEqual(
            #     list(result["loss_2"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["all_logits_2"].size()),
            #     [self.batch_size, self.seq_length, self.vocab_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems_2"]),
            #     [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                                      input_mask, target_mapping, segment_ids, lm_labels,
                                      sequence_labels, is_impossible_labels):
            pass
            # model = XLNetForQuestionAnswering(config)
            # model.eval()
            # outputs = model(input_ids_1)
            # start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
            # outputs = model(input_ids_1, start_positions=sequence_labels,
            #                 end_positions=sequence_labels,
            #                 cls_index=sequence_labels,
            #                 is_impossible=is_impossible_labels,
            #                 p_mask=input_mask)
            # outputs = model(input_ids_1, start_positions=sequence_labels,
            #                 end_positions=sequence_labels,
            #                 cls_index=sequence_labels,
            #                 is_impossible=is_impossible_labels)
            # total_loss, mems = outputs
            # outputs = model(input_ids_1, start_positions=sequence_labels,
            #                 end_positions=sequence_labels)
            # total_loss, mems = outputs
            # result = {
            #     "loss": total_loss,
            #     "start_top_log_probs": start_top_log_probs,
            #     "start_top_index": start_top_index,
            #     "end_top_log_probs": end_top_log_probs,
            #     "end_top_index": end_top_index,
            #     "cls_logits": cls_logits,
            #     "mems": mems,
            # }
            # self.parent.assertListEqual(
            #     list(result["loss"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["start_top_log_probs"].size()),
            #     [self.batch_size, model.config.start_n_top])
            # self.parent.assertListEqual(
            #     list(result["start_top_index"].size()),
            #     [self.batch_size, model.config.start_n_top])
            # self.parent.assertListEqual(
            #     list(result["end_top_log_probs"].size()),
            #     [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            # self.parent.assertListEqual(
            #     list(result["end_top_index"].size()),
            #     [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            # self.parent.assertListEqual(
            #     list(result["cls_logits"].size()),
            #     [self.batch_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems"]),
            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask,
                                                    input_mask, target_mapping, segment_ids, lm_labels,
                                                    sequence_labels, is_impossible_labels):
            pass
            # model = XLNetForSequenceClassification(config)
            # model.eval()
            # logits, mems_1 = model(input_ids_1)
            # loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
            # result = {
            #     "loss": loss,
            #     "mems_1": mems_1,
            #     "logits": logits,
            # }
            # self.parent.assertListEqual(
            #     list(result["loss"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["logits"].size()),
            #     [self.batch_size, self.type_sequence_label_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems_1"]),
            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
             target_mapping, segment_ids, lm_labels,
             sequence_labels, is_impossible_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids_1}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlnet_base_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)

    def test_xlnet_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)

    def test_xlnet_sequence_classif(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

    def test_xlnet_qa(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()
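One detail worth noting in prepare_config_and_inputs: TensorFlow tensors are immutable, so the item assignments left as comments in the file (perm_mask[:, :, -1] = 1.0 and target_mapping[:, 0, -1] = 1.0) are expressed with tf.concat of a zeros block and a ones column instead. A toy-shape sketch of that construction:

    import tensorflow as tf

    batch_size, seq_length = 2, 4
    # Set the last column of the attention permutation mask to 1:
    # previous tokens don't see the last token.
    perm_mask = tf.concat(
        [tf.zeros((batch_size, seq_length + 1, seq_length), dtype=tf.float32),
         tf.ones((batch_size, seq_length + 1, 1), dtype=tf.float32)],
        axis=-1)
    # Mark only the last position as a prediction target.
    target_mapping = tf.concat(
        [tf.zeros((batch_size, 1, seq_length), dtype=tf.float32),
         tf.ones((batch_size, 1, 1), dtype=tf.float32)],
        axis=-1)
    print(perm_mask.shape, target_mapping.shape)   # (2, 5, 5) (2, 1, 5)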