Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
32aabe8c
Commit
32aabe8c
authored
Sep 10, 2019
by
thomwolf
Browse files
WIP XLNet
parent
f851fb55
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1540 additions
and
68 deletions
+1540
-68
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+2
-2
pytorch_transformers/configuration_utils.py
pytorch_transformers/configuration_utils.py
+1
-0
pytorch_transformers/modeling_tf_gpt2.py
pytorch_transformers/modeling_tf_gpt2.py
+10
-64
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+63
-0
pytorch_transformers/modeling_tf_xlnet.py
pytorch_transformers/modeling_tf_xlnet.py
+1121
-0
pytorch_transformers/tests/modeling_tf_common_test.py
pytorch_transformers/tests/modeling_tf_common_test.py
+2
-2
pytorch_transformers/tests/modeling_tf_xlnet_test.py
pytorch_transformers/tests/modeling_tf_xlnet_test.py
+341
-0
No files found.
pytorch_transformers/__init__.py
View file @
32aabe8c
...
@@ -95,7 +95,7 @@ except (ImportError, AssertionError):
...
@@ -95,7 +95,7 @@ except (ImportError, AssertionError):
if
_tf_available
:
if
_tf_available
:
logger
.
info
(
"TensorFlow version {} available."
.
format
(
tf
.
__version__
))
logger
.
info
(
"TensorFlow version {} available."
.
format
(
tf
.
__version__
))
from
.modeling_tf_utils
import
TFPreTrainedModel
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
TFSequenceSummary
from
.modeling_tf_auto
import
(
TFAutoModel
,
TFAutoModelForSequenceClassification
,
TFAutoModelForQuestionAnswering
,
from
.modeling_tf_auto
import
(
TFAutoModel
,
TFAutoModelForSequenceClassification
,
TFAutoModelForQuestionAnswering
,
TFAutoModelWithLMHead
)
TFAutoModelWithLMHead
)
...
@@ -107,7 +107,7 @@ if _tf_available:
...
@@ -107,7 +107,7 @@ if _tf_available:
load_bert_pt_weights_in_tf2
,
load_bert_pt_weights_in_tf2
,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_gpt2
import
(
TFGPT2PreTrainedModel
,
TFGPT2MainLayer
,
TFGPT2Embeddings
,
from
.modeling_tf_gpt2
import
(
TFGPT2PreTrainedModel
,
TFGPT2MainLayer
,
TFGPT2Model
,
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
,
TFGPT2Model
,
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
,
load_gpt2_pt_weights_in_tf2
,
load_gpt2_pt_weights_in_tf2
,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
...
...
pytorch_transformers/configuration_utils.py
View file @
32aabe8c
...
@@ -54,6 +54,7 @@ class PretrainedConfig(object):
...
@@ -54,6 +54,7 @@ class PretrainedConfig(object):
self
.
output_attentions
=
kwargs
.
pop
(
'output_attentions'
,
False
)
self
.
output_attentions
=
kwargs
.
pop
(
'output_attentions'
,
False
)
self
.
output_hidden_states
=
kwargs
.
pop
(
'output_hidden_states'
,
False
)
self
.
output_hidden_states
=
kwargs
.
pop
(
'output_hidden_states'
,
False
)
self
.
torchscript
=
kwargs
.
pop
(
'torchscript'
,
False
)
self
.
torchscript
=
kwargs
.
pop
(
'torchscript'
,
False
)
self
.
use_bfloat16
=
kwargs
.
pop
(
'use_bfloat16'
,
False
)
self
.
pruned_heads
=
kwargs
.
pop
(
'pruned_heads'
,
{})
self
.
pruned_heads
=
kwargs
.
pop
(
'pruned_heads'
,
{})
def
save_pretrained
(
self
,
save_directory
):
def
save_pretrained
(
self
,
save_directory
):
...
...
pytorch_transformers/modeling_tf_gpt2.py
View file @
32aabe8c
...
@@ -28,7 +28,8 @@ from io import open
...
@@ -28,7 +28,8 @@ from io import open
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFConv1D
,
TFSequenceSummary
,
shape_list
from
.modeling_tf_utils
import
(
TFPreTrainedModel
,
TFConv1D
,
TFSharedEmbeddings
,
TFSequenceSummary
,
shape_list
)
from
.configuration_gpt2
import
GPT2Config
from
.configuration_gpt2
import
GPT2Config
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
...
@@ -65,6 +66,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
...
@@ -65,6 +66,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
weight_value_tuples
=
[]
weight_value_tuples
=
[]
all_pytorch_weights
=
set
(
list
(
state_dict
.
keys
()))
for
symbolic_weight
in
symbolic_weights
:
for
symbolic_weight
in
symbolic_weights
:
name
=
symbolic_weight
.
name
name
=
symbolic_weight
.
name
name
=
name
.
replace
(
':0'
,
''
)
name
=
name
.
replace
(
':0'
,
''
)
...
@@ -100,13 +102,13 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
...
@@ -100,13 +102,13 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
state_dict
.
pop
(
name
)
all_pytorch_weights
.
discard
(
name
)
K
.
batch_set_value
(
weight_value_tuples
)
K
.
batch_set_value
(
weight_value_tuples
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
assert
not
state_dict
,
"Weights not loaded: {}"
.
format
(
list
(
state_dict
.
keys
()
))
logger
.
info
(
"Weights or buffers not loaded from PyTorch model: {}"
.
format
(
all_pytorch_weights
))
return
tf_model
return
tf_model
...
@@ -267,65 +269,6 @@ class TFBlock(tf.keras.layers.Layer):
...
@@ -267,65 +269,6 @@ class TFBlock(tf.keras.layers.Layer):
outputs
=
[
x
]
+
output_attn
[
1
:]
outputs
=
[
x
]
+
output_attn
[
1
:]
return
outputs
# x, present, (attentions)
return
outputs
# x, present, (attentions)
class
TFGPT2Embeddings
(
tf
.
keras
.
layers
.
Layer
):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def
__init__
(
self
,
config
,
**
kwargs
):
super
(
TFGPT2Embeddings
,
self
).
__init__
(
**
kwargs
)
self
.
vocab_size
=
config
.
vocab_size
self
.
hidden_size
=
config
.
hidden_size
def
build
(
self
,
input_shape
):
"""Build shared word embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self
.
weight
=
self
.
add_weight
(
"weight"
,
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
initializer
=
tf
.
random_normal_initializer
(
mean
=
0.
,
stddev
=
self
.
hidden_size
**-
0.5
))
super
(
TFGPT2Embeddings
,
self
).
build
(
input_shape
)
def
call
(
self
,
inputs
,
mode
=
"embedding"
):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if
mode
==
"embedding"
:
return
self
.
_embedding
(
inputs
)
elif
mode
==
"linear"
:
return
self
.
_linear
(
inputs
)
else
:
raise
ValueError
(
"mode {} is not valid."
.
format
(
mode
))
def
_embedding
(
self
,
input_ids
):
"""Applies embedding based on inputs tensor."""
return
tf
.
gather
(
self
.
weight
,
input_ids
)
def
_linear
(
self
,
inputs
):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [..., vocab_size].
"""
first_dims
=
shape_list
(
inputs
)[:
-
1
]
x
=
tf
.
reshape
(
inputs
,
[
-
1
,
self
.
hidden_size
])
logits
=
tf
.
matmul
(
x
,
self
.
weight
,
transpose_b
=
True
)
return
tf
.
reshape
(
logits
,
first_dims
+
[
self
.
vocab_size
])
class
TFGPT2MainLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFGPT2MainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
...
@@ -336,10 +279,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
...
@@ -336,10 +279,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self
.
vocab_size
=
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
n_embd
=
config
.
n_embd
self
.
n_embd
=
config
.
n_embd
self
.
wte
=
TF
GPT2
Embeddings
(
config
,
name
=
'wte'
)
self
.
wte
=
TF
Shared
Embeddings
(
config
.
vocab_size
,
config
.
hidden_size
,
name
=
'wte'
)
self
.
wpe
=
tf
.
keras
.
layers
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
,
name
=
'wpe'
)
self
.
wpe
=
tf
.
keras
.
layers
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
,
name
=
'wpe'
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
h
=
[
TFBlock
(
config
.
n_ctx
,
config
,
scale
=
True
,
name
=
'h_{}'
.
format
(
i
))
for
i
in
range
(
config
.
n_layer
)]
self
.
h
=
[
TFBlock
(
config
.
n_ctx
,
config
,
scale
=
True
,
name
=
'h_{}'
.
format
(
i
))
for
i
in
range
(
config
.
n_layer
)]
self
.
ln_f
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
config
.
layer_norm_epsilon
,
name
=
'ln_f'
)
self
.
ln_f
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
config
.
layer_norm_epsilon
,
name
=
'ln_f'
)
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
...
...
pytorch_transformers/modeling_tf_utils.py
View file @
32aabe8c
...
@@ -288,6 +288,69 @@ class TFConv1D(tf.keras.layers.Layer):
...
@@ -288,6 +288,69 @@ class TFConv1D(tf.keras.layers.Layer):
return
x
return
x
class
TFSharedEmbeddings
(
tf
.
keras
.
layers
.
Layer
):
"""Construct shared token embeddings.
"""
def
__init__
(
self
,
vocab_size
,
hidden_size
,
initializer_range
=
None
,
**
kwargs
):
super
(
TFSharedEmbeddings
,
self
).
__init__
(
**
kwargs
)
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
initializer_range
=
initializer_range
def
build
(
self
,
input_shape
):
"""Build shared word embedding layer
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
initializer_range
=
self
.
hidden_size
**-
0.5
if
self
.
initializer_range
is
None
else
self
.
initializer_range
self
.
weight
=
self
.
add_weight
(
"weight"
,
shape
=
[
self
.
vocab_size
,
self
.
hidden_size
],
initializer
=
tf
.
random_normal_initializer
(
mean
=
0.
,
stddev
=
initializer_range
))
super
(
TFSharedEmbeddings
,
self
).
build
(
input_shape
)
def
call
(
self
,
inputs
,
mode
=
"embedding"
):
"""Get token embeddings of inputs.
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
mode: string, a valid value is one of "embedding" and "linear".
Returns:
outputs: (1) If mode == "embedding", output embedding tensor, float32 with
shape [batch_size, length, embedding_size]; (2) mode == "linear", output
linear tensor, float32 with shape [batch_size, length, vocab_size].
Raises:
ValueError: if mode is not valid.
Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if
mode
==
"embedding"
:
return
self
.
_embedding
(
inputs
)
elif
mode
==
"linear"
:
return
self
.
_linear
(
inputs
)
else
:
raise
ValueError
(
"mode {} is not valid."
.
format
(
mode
))
def
_embedding
(
self
,
input_ids
):
"""Applies embedding based on inputs tensor."""
return
tf
.
gather
(
self
.
weight
,
input_ids
)
def
_linear
(
self
,
inputs
):
"""Computes logits by running inputs through a linear layer.
Args:
inputs: A float32 tensor with shape [..., hidden_size]
Returns:
float32 tensor with shape [..., vocab_size].
"""
first_dims
=
shape_list
(
inputs
)[:
-
1
]
x
=
tf
.
reshape
(
inputs
,
[
-
1
,
self
.
hidden_size
])
logits
=
tf
.
matmul
(
x
,
self
.
weight
,
transpose_b
=
True
)
return
tf
.
reshape
(
logits
,
first_dims
+
[
self
.
vocab_size
])
class
TFSequenceSummary
(
tf
.
keras
.
layers
.
Layer
):
class
TFSequenceSummary
(
tf
.
keras
.
layers
.
Layer
):
r
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
r
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
Args of the config class:
...
...
pytorch_transformers/modeling_tf_xlnet.py
0 → 100644
View file @
32aabe8c
This diff is collapsed.
Click to expand it.
pytorch_transformers/tests/modeling_tf_common_test.py
View file @
32aabe8c
...
@@ -262,7 +262,7 @@ class TFCommonTestCases:
...
@@ -262,7 +262,7 @@ class TFCommonTestCases:
# self.assertEqual(len(params_tied_2), len(params_tied))
# self.assertEqual(len(params_tied_2), len(params_tied))
def
ids_tensor
(
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
def
ids_tensor
(
shape
,
vocab_size
,
rng
=
None
,
name
=
None
,
dtype
=
tf
.
int32
):
"""Creates a random int32 tensor of the shape within the vocab size."""
"""Creates a random int32 tensor of the shape within the vocab size."""
if
rng
is
None
:
if
rng
is
None
:
rng
=
random
.
Random
()
rng
=
random
.
Random
()
...
@@ -275,7 +275,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
...
@@ -275,7 +275,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
for
_
in
range
(
total_dims
):
for
_
in
range
(
total_dims
):
values
.
append
(
rng
.
randint
(
0
,
vocab_size
-
1
))
values
.
append
(
rng
.
randint
(
0
,
vocab_size
-
1
))
return
tf
.
constant
(
values
,
shape
=
shape
)
return
tf
.
constant
(
values
,
shape
=
shape
,
dtype
=
dtype
)
class
TFModelUtilsTest
(
unittest
.
TestCase
):
class
TFModelUtilsTest
(
unittest
.
TestCase
):
...
...
pytorch_transformers/tests/modeling_tf_xlnet_test.py
0 → 100644
View file @
32aabe8c
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
unittest
import
json
import
random
import
shutil
import
pytest
from
pytorch_transformers
import
XLNetConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
pytorch_transformers.modeling_tf_xlnet
import
(
TFXLNetModel
,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
# XLNetLMHeadModel,
# XLNetForSequenceClassification, XLNetForQuestionAnswering)
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
class
TFXLNetModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFXLNetModel
,
)
if
is_tf_available
()
else
()
# all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
# TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else ()
test_pruning
=
False
class
TFXLNetModelTester
(
object
):
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
mem_len
=
10
,
clamp_len
=-
1
,
reuse_len
=
15
,
is_training
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
cutoffs
=
[
10
,
50
,
80
],
hidden_size
=
32
,
num_attention_heads
=
4
,
d_inner
=
128
,
num_hidden_layers
=
5
,
max_position_embeddings
=
10
,
type_sequence_label_size
=
2
,
untie_r
=
True
,
bi_data
=
False
,
same_length
=
False
,
initializer_range
=
0.05
,
seed
=
1
,
type_vocab_size
=
2
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
seq_length
=
seq_length
self
.
mem_len
=
mem_len
# self.key_len = seq_length + mem_len
self
.
clamp_len
=
clamp_len
self
.
reuse_len
=
reuse_len
self
.
is_training
=
is_training
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
cutoffs
=
cutoffs
self
.
hidden_size
=
hidden_size
self
.
num_attention_heads
=
num_attention_heads
self
.
d_inner
=
d_inner
self
.
num_hidden_layers
=
num_hidden_layers
self
.
max_position_embeddings
=
max_position_embeddings
self
.
bi_data
=
bi_data
self
.
untie_r
=
untie_r
self
.
same_length
=
same_length
self
.
initializer_range
=
initializer_range
self
.
seed
=
seed
self
.
type_vocab_size
=
type_vocab_size
self
.
type_sequence_label_size
=
type_sequence_label_size
def
prepare_config_and_inputs
(
self
):
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
segment_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
,
dtype
=
tf
.
float32
)
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
tf
.
zeros
((
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
),
dtype
=
tf
.
float32
)
perm_mask_last
=
tf
.
ones
((
self
.
batch_size
,
self
.
seq_length
+
1
,
1
),
dtype
=
tf
.
float32
)
perm_mask
=
tf
.
concat
([
perm_mask
,
perm_mask_last
],
axis
=-
1
)
# perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping
=
tf
.
zeros
((
self
.
batch_size
,
1
,
self
.
seq_length
),
dtype
=
torch
.
float32
)
target_mapping_last
=
tf
.
ones
((
self
.
batch_size
,
1
,
1
),
dtype
=
torch
.
float32
)
target_mapping
=
tf
.
concat
([
target_mapping
,
target_mapping_last
],
axis
=-
1
)
# target_mapping[:, 0, -1] = 1.0 # predict last token
sequence_labels
=
None
lm_labels
=
None
is_impossible_labels
=
None
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
is_impossible_labels
=
ids_tensor
([
self
.
batch_size
],
2
,
dtype
=
tf
.
float32
)
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
d_model
=
self
.
hidden_size
,
n_head
=
self
.
num_attention_heads
,
d_inner
=
self
.
d_inner
,
n_layer
=
self
.
num_hidden_layers
,
untie_r
=
self
.
untie_r
,
max_position_embeddings
=
self
.
max_position_embeddings
,
mem_len
=
self
.
mem_len
,
clamp_len
=
self
.
clamp_len
,
same_length
=
self
.
same_length
,
reuse_len
=
self
.
reuse_len
,
bi_data
=
self
.
bi_data
,
initializer_range
=
self
.
initializer_range
,
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
tf
.
random
.
set_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
model
=
TFXLNetModel
(
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'input_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
_
,
_
=
model
(
inputs
)
inputs
=
[
input_ids
,
input_mask
]
outputs
,
mems_1
=
model
(
inputs
)
result
=
{
"mems_1"
:
[
mem
.
numpy
()
for
m
in
mems_1
],
"outputs"
:
outputs
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"outputs"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
pass
# model = XLNetLMHeadModel(config)
# model.eval()
# loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
# loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
# logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
# result = {
# "loss_1": loss_1,
# "mems_1": mems_1,
# "all_logits_1": all_logits_1,
# "loss_2": loss_2,
# "mems_2": mems_2,
# "all_logits_2": all_logits_2,
# }
# self.parent.assertListEqual(
# list(result["loss_1"].size()),
# [])
# self.parent.assertListEqual(
# list(result["all_logits_1"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_1"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
# self.parent.assertListEqual(
# list(result["loss_2"].size()),
# [])
# self.parent.assertListEqual(
# list(result["all_logits_2"].size()),
# [self.batch_size, self.seq_length, self.vocab_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_2"]),
# [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
pass
# model = XLNetForQuestionAnswering(config)
# model.eval()
# outputs = model(input_ids_1)
# start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels,
# cls_index=sequence_labels,
# is_impossible=is_impossible_labels,
# p_mask=input_mask)
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels,
# cls_index=sequence_labels,
# is_impossible=is_impossible_labels)
# total_loss, mems = outputs
# outputs = model(input_ids_1, start_positions=sequence_labels,
# end_positions=sequence_labels)
# total_loss, mems = outputs
# result = {
# "loss": total_loss,
# "start_top_log_probs": start_top_log_probs,
# "start_top_index": start_top_index,
# "end_top_log_probs": end_top_log_probs,
# "end_top_index": end_top_index,
# "cls_logits": cls_logits,
# "mems": mems,
# }
# self.parent.assertListEqual(
# list(result["loss"].size()),
# [])
# self.parent.assertListEqual(
# list(result["start_top_log_probs"].size()),
# [self.batch_size, model.config.start_n_top])
# self.parent.assertListEqual(
# list(result["start_top_index"].size()),
# [self.batch_size, model.config.start_n_top])
# self.parent.assertListEqual(
# list(result["end_top_log_probs"].size()),
# [self.batch_size, model.config.start_n_top * model.config.end_n_top])
# self.parent.assertListEqual(
# list(result["end_top_index"].size()),
# [self.batch_size, model.config.start_n_top * model.config.end_n_top])
# self.parent.assertListEqual(
# list(result["cls_logits"].size()),
# [self.batch_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
pass
# model = XLNetForSequenceClassification(config)
# model.eval()
# logits, mems_1 = model(input_ids_1)
# loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
# result = {
# "loss": loss,
# "mems_1": mems_1,
# "logits": logits,
# }
# self.parent.assertListEqual(
# list(result["loss"].size()),
# [])
# self.parent.assertListEqual(
# list(result["logits"].size()),
# [self.batch_size, self.type_sequence_label_size])
# self.parent.assertListEqual(
# list(list(mem.size()) for mem in result["mems_1"]),
# [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFXLNetModelTest
.
TFXLNetModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
XLNetConfig
,
d_inner
=
37
)
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
def
test_xlnet_base_model
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_base_model
(
*
config_and_inputs
)
def
test_xlnet_lm_head
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
def
test_xlnet_sequence_classif
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
def
test_xlnet_qa
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_qa
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_transformers_test/"
for
model_name
in
list
(
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFXLNetModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment