Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
65c49bb2
Commit
65c49bb2
authored
Sep 13, 2019
by
thomwolf
Browse files
adding TF 2.0 adaptive softmax with logits + loss outputs
parent
39c38b2e
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
869 additions
and
1105 deletions
+869
-1105
pytorch_transformers/modeling_tf_bert.py
pytorch_transformers/modeling_tf_bert.py
+2
-0
pytorch_transformers/modeling_tf_transfo_xl.py
pytorch_transformers/modeling_tf_transfo_xl.py
+344
-729
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
+279
-0
pytorch_transformers/modeling_tf_xlm.py
pytorch_transformers/modeling_tf_xlm.py
+2
-2
pytorch_transformers/modeling_transfo_xl.py
pytorch_transformers/modeling_transfo_xl.py
+20
-373
pytorch_transformers/tests/modeling_tf_bert_test.py
pytorch_transformers/tests/modeling_tf_bert_test.py
+5
-1
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
+217
-0
No files found.
pytorch_transformers/modeling_tf_bert.py
View file @
65c49bb2
...
@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
...
@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
# def call(self, input_ids, attention_mask=None, token_type_ids=None,
# position_ids=None, head_mask=None, training=False):
def
call
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
if
not
isinstance
(
inputs
,
(
dict
,
tuple
,
list
)):
if
not
isinstance
(
inputs
,
(
dict
,
tuple
,
list
)):
input_ids
=
inputs
input_ids
=
inputs
...
...
pytorch_transformers/modeling_tf_transfo_xl.py
View file @
65c49bb2
This diff is collapsed.
Click to expand it.
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
0 → 100644
View file @
65c49bb2
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""
from
collections
import
defaultdict
import
numpy
as
np
import
tensorflow
as
tf
from
.modeling_tf_utils
import
shape_list
class
TFAdaptiveSoftmaxMask
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
n_token
,
d_embed
,
d_proj
,
cutoffs
,
div_val
=
1
,
keep_order
=
False
,
**
kwargs
):
super
(
TFAdaptiveSoftmaxMask
,
self
).
__init__
(
**
kwargs
)
self
.
n_token
=
n_token
self
.
d_embed
=
d_embed
self
.
d_proj
=
d_proj
self
.
cutoffs
=
cutoffs
+
[
n_token
]
self
.
cutoff_ends
=
[
0
]
+
self
.
cutoffs
self
.
div_val
=
div_val
self
.
shortlist_size
=
self
.
cutoffs
[
0
]
self
.
n_clusters
=
len
(
self
.
cutoffs
)
-
1
self
.
head_size
=
self
.
shortlist_size
+
self
.
n_clusters
self
.
keep_order
=
keep_order
self
.
out_layers
=
[]
self
.
out_projs
=
[]
def
build
(
self
,
input_shape
):
if
self
.
n_clusters
>
0
:
self
.
cluster_weight
=
self
.
add_weight
(
shape
=
(
self
.
n_clusters
,
self
.
d_embed
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'cluster_weight'
)
self
.
cluster_bias
=
self
.
add_weight
(
shape
=
(
self
.
n_clusters
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'cluster_bias'
)
if
self
.
div_val
==
1
:
for
i
in
range
(
len
(
self
.
cutoffs
)):
if
self
.
d_proj
!=
self
.
d_embed
:
weight
=
self
.
add_weight
(
shape
=
(
self
.
d_embed
,
self
.
d_proj
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_projs_._{}'
.
format
(
i
))
self
.
out_projs
.
append
(
weight
)
else
:
self
.
out_projs
.
append
(
None
)
weight
=
self
.
add_weight
(
shape
=
(
self
.
n_token
,
self
.
d_embed
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._weight'
.
format
(
i
))
bias
=
self
.
add_weight
(
shape
=
(
self
.
n_token
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._bias'
.
format
(
i
))
self
.
out_layers
.
append
((
weight
,
bias
))
else
:
for
i
in
range
(
len
(
self
.
cutoffs
)):
l_idx
,
r_idx
=
self
.
cutoff_ends
[
i
],
self
.
cutoff_ends
[
i
+
1
]
d_emb_i
=
self
.
d_embed
//
(
self
.
div_val
**
i
)
weight
=
self
.
add_weight
(
shape
=
(
d_emb_i
,
self
.
d_proj
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_projs_._{}'
.
format
(
i
))
self
.
out_projs
.
append
(
weight
)
weight
=
self
.
add_weight
(
shape
=
(
r_idx
-
l_idx
,
d_emb_i
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._weight'
.
format
(
i
))
bias
=
self
.
add_weight
(
shape
=
(
r_idx
-
l_idx
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._bias'
.
format
(
i
))
self
.
out_layers
.
append
((
weight
,
bias
))
super
(
TFAdaptiveSoftmaxMask
,
self
).
build
(
input_shape
)
@
staticmethod
def
_logit
(
x
,
W
,
b
,
proj
=
None
):
y
=
x
if
proj
is
not
None
:
y
=
tf
.
einsum
(
'ibd,ed->ibe'
,
y
,
proj
)
return
tf
.
einsum
(
'ibd,nd->ibn'
,
y
,
W
)
+
b
@
staticmethod
def
_gather_logprob
(
logprob
,
target
):
lp_size
=
tf
.
shape
(
logprob
)
r
=
tf
.
range
(
lp_size
[
0
])
idx
=
tf
.
stack
([
r
,
target
],
1
)
return
tf
.
gather_nd
(
logprob
,
idx
)
def
call
(
self
,
inputs
,
return_mean
=
True
,
training
=
False
):
hidden
,
target
=
inputs
head_logprob
=
0
if
self
.
n_clusters
==
0
:
softmax_b
=
tf
.
get_variable
(
'bias'
,
[
n_token
],
initializer
=
tf
.
zeros_initializer
())
output
=
self
.
_logit
(
hidden
,
self
.
out_layers
[
0
][
0
],
self
.
out_layers
[
0
][
1
],
self
.
out_projs
[
0
])
if
target
is
not
None
:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
target
,
logits
=
output
)
out
=
tf
.
nn
.
log_softmax
(
output
,
axis
=-
1
)
else
:
hidden_sizes
=
shape_list
(
hidden
)
out
=
[]
loss
=
tf
.
zeros
(
hidden_sizes
[:
2
],
dtype
=
tf
.
float32
)
for
i
in
range
(
len
(
self
.
cutoffs
)):
l_idx
,
r_idx
=
self
.
cutoff_ends
[
i
],
self
.
cutoff_ends
[
i
+
1
]
if
target
is
not
None
:
mask
=
(
target
>=
l_idx
)
&
(
target
<
r_idx
)
mask_idx
=
tf
.
where
(
mask
)
cur_target
=
tf
.
boolean_mask
(
target
,
mask
)
-
l_idx
if
self
.
div_val
==
1
:
cur_W
=
self
.
out_layers
[
0
][
0
][
l_idx
:
r_idx
]
cur_b
=
self
.
out_layers
[
0
][
1
][
l_idx
:
r_idx
]
else
:
cur_W
=
self
.
out_layers
[
i
][
0
]
cur_b
=
self
.
out_layers
[
i
][
1
]
if
i
==
0
:
cur_W
=
tf
.
concat
([
cur_W
,
self
.
cluster_weight
],
0
)
cur_b
=
tf
.
concat
([
cur_b
,
self
.
cluster_bias
],
0
)
head_logit
=
self
.
_logit
(
hidden
,
cur_W
,
cur_b
,
self
.
out_projs
[
0
])
head_logprob
=
tf
.
nn
.
log_softmax
(
head_logit
)
out
.
append
(
head_logprob
[...,
:
self
.
cutoffs
[
0
]])
if
target
is
not
None
:
cur_head_logprob
=
tf
.
boolean_mask
(
head_logprob
,
mask
)
cur_logprob
=
self
.
_gather_logprob
(
cur_head_logprob
,
cur_target
)
else
:
tail_logit
=
self
.
_logit
(
hidden
,
cur_W
,
cur_b
,
self
.
out_projs
[
i
])
tail_logprob
=
tf
.
nn
.
log_softmax
(
tail_logit
)
cluster_prob_idx
=
self
.
cutoffs
[
0
]
+
i
-
1
# No probability for the head cluster
logprob_i
=
head_logprob
[...,
cluster_prob_idx
,
None
]
+
tail_logprob
out
.
append
(
logprob_i
)
if
target
is
not
None
:
cur_head_logprob
=
tf
.
boolean_mask
(
head_logprob
,
mask
)
cur_tail_logprob
=
tf
.
boolean_mask
(
tail_logprob
,
mask
)
cur_logprob
=
self
.
_gather_logprob
(
cur_tail_logprob
,
cur_target
)
cur_logprob
+=
cur_head_logprob
[:,
self
.
cutoff_ends
[
1
]
+
i
-
1
]
if
target
is
not
None
:
loss
+=
tf
.
scatter_nd
(
mask_idx
,
-
cur_logprob
,
tf
.
cast
(
tf
.
shape
(
loss
),
dtype
=
tf
.
int64
))
out
=
tf
.
concat
(
out
,
axis
=-
1
)
if
target
is
not
None
:
if
return_mean
:
loss
=
tf
.
reduce_mean
(
loss
)
# Add the training-time loss value to the layer using `self.add_loss()`.
self
.
add_loss
(
loss
)
# Log the loss as a metric (we could log arbitrary metrics,
# including different metrics for training and inference.
self
.
add_metric
(
loss
,
name
=
self
.
name
,
aggregation
=
'mean'
if
return_mean
else
''
)
return
out
def
mul_adaptive_logsoftmax
(
hidden
,
target
,
n_token
,
d_embed
,
d_proj
,
cutoffs
,
params
,
tie_projs
,
initializer
=
None
,
proj_initializer
=
None
,
div_val
=
1
,
perms
=
None
,
proj_same_dim
=
True
,
scope
=
'adaptive_softmax'
,
**
kwargs
):
def
_logit
(
x
,
W
,
b
,
proj
):
y
=
x
if
x
.
shape
.
ndims
==
3
:
if
proj
is
not
None
:
y
=
tf
.
einsum
(
'ibd,ed->ibe'
,
y
,
proj
)
return
tf
.
einsum
(
'ibd,nd->ibn'
,
y
,
W
)
+
b
else
:
if
proj
is
not
None
:
y
=
tf
.
einsum
(
'id,ed->ie'
,
y
,
proj
)
return
tf
.
einsum
(
'id,nd->in'
,
y
,
W
)
+
b
params_W
,
params_projs
=
params
[
0
],
params
[
1
]
with
tf
.
variable_scope
(
scope
):
if
len
(
cutoffs
)
==
0
:
softmax_b
=
tf
.
get_variable
(
'bias'
,
[
n_token
],
initializer
=
tf
.
zeros_initializer
())
output
=
_logit
(
hidden
,
params_W
,
softmax_b
,
params_projs
)
nll
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
target
,
logits
=
output
)
nll
=
tf
.
reduce_mean
(
nll
)
else
:
total_loss
,
total_cnt
=
0
,
0
cutoff_ends
=
[
0
]
+
cutoffs
+
[
n_token
]
for
i
in
range
(
len
(
cutoff_ends
)
-
1
):
with
tf
.
variable_scope
(
'cutoff_{}'
.
format
(
i
)):
l_idx
,
r_idx
=
cutoff_ends
[
i
],
cutoff_ends
[
i
+
1
]
cur_d_embed
=
d_embed
//
(
div_val
**
i
)
if
div_val
==
1
:
cur_W
=
params_W
[
l_idx
:
r_idx
]
else
:
cur_W
=
params_W
[
i
]
cur_b
=
tf
.
get_variable
(
'b'
,
[
r_idx
-
l_idx
],
initializer
=
tf
.
zeros_initializer
())
if
tie_projs
[
i
]:
if
div_val
==
1
:
cur_proj
=
params_projs
else
:
cur_proj
=
params_projs
[
i
]
else
:
if
(
div_val
==
1
or
not
proj_same_dim
)
and
d_proj
==
cur_d_embed
:
cur_proj
=
None
else
:
cur_proj
=
tf
.
get_variable
(
'proj'
,
[
cur_d_embed
,
d_proj
],
initializer
=
proj_initializer
)
if
i
==
0
:
cluster_W
=
tf
.
get_variable
(
'cluster_W'
,
[
len
(
cutoffs
),
d_embed
],
initializer
=
tf
.
zeros_initializer
())
cluster_b
=
tf
.
get_variable
(
'cluster_b'
,
[
len
(
cutoffs
)],
initializer
=
tf
.
zeros_initializer
())
cur_W
=
tf
.
concat
([
cur_W
,
cluster_W
],
0
)
cur_b
=
tf
.
concat
([
cur_b
,
cluster_b
],
0
)
head_logit
=
_logit
(
hidden
,
cur_W
,
cur_b
,
cur_proj
)
head_target
=
kwargs
.
get
(
"head_target"
)
head_nll
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
head_target
,
logits
=
head_logit
)
masked_loss
=
head_nll
*
perms
[
i
]
total_loss
+=
tf
.
reduce_sum
(
masked_loss
)
total_cnt
+=
tf
.
reduce_sum
(
perms
[
i
])
# head_logprob = tf.nn.log_softmax(head_logit)
# final_logprob = head_logprob * perms[i][:, :, None]
# final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
# total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
# total_cnt += tf.reduce_sum(perms[i])
else
:
cur_head_nll
=
tf
.
einsum
(
'ib,ibk->k'
,
head_nll
,
perms
[
i
])
cur_hidden
=
tf
.
einsum
(
'ibd,ibk->kd'
,
hidden
,
perms
[
i
])
tail_logit
=
_logit
(
cur_hidden
,
cur_W
,
cur_b
,
cur_proj
)
tail_target
=
tf
.
einsum
(
'ib,ibk->k'
,
tf
.
to_float
(
target
-
l_idx
),
perms
[
i
])
tail_nll
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
tf
.
to_int32
(
tail_target
),
logits
=
tail_logit
)
sum_nll
=
cur_head_nll
+
tail_nll
mask
=
tf
.
reduce_sum
(
perms
[
i
],
[
0
,
1
])
masked_loss
=
sum_nll
*
mask
total_loss
+=
tf
.
reduce_sum
(
masked_loss
)
total_cnt
+=
tf
.
reduce_sum
(
mask
)
nll
=
total_loss
/
total_cnt
return
nll
\ No newline at end of file
pytorch_transformers/modeling_tf_xlm.py
View file @
65c49bb2
...
@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
...
@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
self
.
ffns
=
[]
self
.
ffns
=
[]
self
.
layer_norm2
=
[]
self
.
layer_norm2
=
[]
# if self.is_decoder:
# if self.is_decoder:
# self.layer_norm15 =
tf.keras.layers.LayerList()
# self.layer_norm15 =
[]
# self.encoder_attn =
tf.keras.layers.LayerList()
# self.encoder_attn =
[]
for
i
in
range
(
self
.
n_layers
):
for
i
in
range
(
self
.
n_layers
):
self
.
attentions
.
append
(
TFMultiHeadAttention
(
self
.
n_heads
,
self
.
dim
,
config
=
config
,
name
=
'attentions_._{}'
.
format
(
i
)))
self
.
attentions
.
append
(
TFMultiHeadAttention
(
self
.
n_heads
,
self
.
dim
,
config
=
config
,
name
=
'attentions_._{}'
.
format
(
i
)))
...
...
pytorch_transformers/modeling_transfo_xl.py
View file @
65c49bb2
This diff is collapsed.
Click to expand it.
pytorch_transformers/tests/modeling_tf_bert_test.py
View file @
65c49bb2
...
@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def
create_and_check_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
def
create_and_check_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFBertModel
(
config
=
config
)
model
=
TFBertModel
(
config
=
config
)
# inputs = {'input_ids': input_ids,
# 'attention_mask': input_mask,
# 'token_type_ids': token_type_ids}
# sequence_output, pooled_output = model(**inputs)
inputs
=
{
'input_ids'
:
input_ids
,
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
'token_type_ids'
:
token_type_ids
}
sequence_output
,
pooled_output
=
model
(
inputs
)
sequence_output
,
pooled_output
=
model
(
input
_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_id
s
)
inputs
=
[
input_ids
,
input_mask
]
inputs
=
[
input_ids
,
input_mask
]
sequence_output
,
pooled_output
=
model
(
inputs
)
sequence_output
,
pooled_output
=
model
(
inputs
)
...
...
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
0 → 100644
View file @
65c49bb2
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
random
import
shutil
import
pytest
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
pytorch_transformers
import
TransfoXLConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
pytorch_transformers.modeling_tf_transfo_xl
import
(
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
class
TFTransfoXLModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
)
if
is_tf_available
()
else
()
test_pruning
=
False
test_torchscript
=
False
test_resize_embeddings
=
False
class
TFTransfoXLModelTester
(
object
):
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
mem_len
=
30
,
clamp_len
=
15
,
is_training
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
cutoffs
=
[
10
,
50
,
80
],
hidden_size
=
32
,
d_embed
=
32
,
num_attention_heads
=
4
,
d_head
=
8
,
d_inner
=
128
,
div_val
=
2
,
num_hidden_layers
=
5
,
scope
=
None
,
seed
=
1
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
seq_length
=
seq_length
self
.
mem_len
=
mem_len
self
.
key_len
=
seq_length
+
mem_len
self
.
clamp_len
=
clamp_len
self
.
is_training
=
is_training
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
cutoffs
=
cutoffs
self
.
hidden_size
=
hidden_size
self
.
d_embed
=
d_embed
self
.
num_attention_heads
=
num_attention_heads
self
.
d_head
=
d_head
self
.
d_inner
=
d_inner
self
.
div_val
=
div_val
self
.
num_hidden_layers
=
num_hidden_layers
self
.
scope
=
scope
self
.
seed
=
seed
def
prepare_config_and_inputs
(
self
):
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
lm_labels
=
None
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
TransfoXLConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
mem_len
=
self
.
mem_len
,
clamp_len
=
self
.
clamp_len
,
cutoffs
=
self
.
cutoffs
,
d_model
=
self
.
hidden_size
,
d_embed
=
self
.
d_embed
,
n_head
=
self
.
num_attention_heads
,
d_head
=
self
.
d_head
,
d_inner
=
self
.
d_inner
,
div_val
=
self
.
div_val
,
n_layer
=
self
.
num_hidden_layers
)
return
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
tf
.
random
.
set_seed
(
self
.
seed
)
def
create_and_check_transfo_xl_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TFTransfoXLModel
(
config
)
hidden_states_1
,
mems_1
=
model
(
input_ids_1
)
inputs
=
{
'input_ids'
:
input_ids_2
,
'mems'
:
mems_1
}
hidden_states_2
,
mems_2
=
model
(
inputs
)
result
=
{
"hidden_states_1"
:
hidden_states_1
.
numpy
(),
"mems_1"
:
[
mem
.
numpy
()
for
mem
in
mems_1
],
"hidden_states_2"
:
hidden_states_2
.
numpy
(),
"mems_2"
:
[
mem
.
numpy
()
for
mem
in
mems_2
],
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TFTransfoXLLMHeadModel
(
config
)
lm_logits_1
,
mems_1
=
model
(
input_ids_1
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'labels'
:
lm_labels
}
_
,
mems_1
=
model
(
inputs
)
lm_logits_2
,
mems_2
=
model
([
input_ids_2
,
mems_1
])
inputs
=
{
'input_ids'
:
input_ids_1
,
'mems'
:
mems_1
,
'labels'
:
lm_labels
}
_
,
mems_2
=
model
(
inputs
)
result
=
{
"mems_1"
:
[
mem
.
numpy
()
for
mem
in
mems_1
],
"lm_logits_1"
:
lm_logits_1
.
numpy
(),
"mems_2"
:
[
mem
.
numpy
()
for
mem
in
mems_2
],
"lm_logits_2"
:
lm_logits_2
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFTransfoXLModelTest
.
TFTransfoXLModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
TransfoXLConfig
,
d_embed
=
37
)
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
def
test_transfo_xl_model
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_transfo_xl_model
(
*
config_and_inputs
)
def
test_transfo_xl_lm_head
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_transfo_xl_lm_head
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_transformers_test/"
for
model_name
in
list
(
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFTransfoXLModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment