Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c56d921d
Commit
c56d921d
authored
Oct 09, 2019
by
thomwolf
Browse files
adding TF 2.0 model
parent
45dc04f3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
426 additions
and
209 deletions
+426
-209
transformers/configuration_ctrl.py
transformers/configuration_ctrl.py
+2
-2
transformers/modeling_ctrl.py
transformers/modeling_ctrl.py
+1
-1
transformers/modeling_roberta.py
transformers/modeling_roberta.py
+2
-1
transformers/modeling_tf_ctrl.py
transformers/modeling_tf_ctrl.py
+219
-204
transformers/tests/modeling_tf_ctrl_test.py
transformers/tests/modeling_tf_ctrl_test.py
+201
-0
transformers/tests/modeling_tf_gpt2_test.py
transformers/tests/modeling_tf_gpt2_test.py
+1
-1
No files found.
transformers/configuration_ctrl.py
View file @
c56d921d
...
@@ -53,8 +53,8 @@ class CTRLConfig(PretrainedConfig):
...
@@ -53,8 +53,8 @@ class CTRLConfig(PretrainedConfig):
def
__init__
(
def
__init__
(
self
,
self
,
vocab_size_or_config_json_file
=
246534
,
vocab_size_or_config_json_file
=
246534
,
n_positions
=
50000
,
n_positions
=
256
,
n_ctx
=
512
,
n_ctx
=
256
,
n_embd
=
1280
,
n_embd
=
1280
,
dff
=
8192
,
dff
=
8192
,
n_layer
=
48
,
n_layer
=
48
,
...
...
transformers/modeling_ctrl.py
View file @
c56d921d
...
@@ -351,7 +351,7 @@ class CTRLModel(CTRLPreTrainedModel):
...
@@ -351,7 +351,7 @@ class CTRLModel(CTRLPreTrainedModel):
x
=
self
.
w
(
input_ids
)
x
=
self
.
w
(
input_ids
)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_ids
.
shape
[
1
]
seq_len
=
input_ids
.
shape
[
-
1
]
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
x
.
device
)
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
x
.
device
)
x
*=
np
.
sqrt
(
self
.
d_model_size
)
x
*=
np
.
sqrt
(
self
.
d_model_size
)
...
...
transformers/modeling_roberta.py
View file @
c56d921d
...
@@ -172,7 +172,8 @@ class RobertaModel(BertModel):
...
@@ -172,7 +172,8 @@ class RobertaModel(BertModel):
if
input_ids
[:,
0
].
sum
().
item
()
!=
0
:
if
input_ids
[:,
0
].
sum
().
item
()
!=
0
:
logger
.
warning
(
"A sequence with no special tokens has been passed to the RoBERTa model. "
logger
.
warning
(
"A sequence with no special tokens has been passed to the RoBERTa model. "
"This model requires special tokens in order to work. "
"This model requires special tokens in order to work. "
"Please specify add_special_tokens=True in your encoding."
)
"Please specify add_special_tokens=True in your tokenize.encode()"
"or tokenizer.convert_tokens_to_ids()."
)
return
super
(
RobertaModel
,
self
).
forward
(
input_ids
,
return
super
(
RobertaModel
,
self
).
forward
(
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
token_type_ids
=
token_type_ids
,
...
...
transformers/modeling_tf_ctrl.py
View file @
c56d921d
...
@@ -25,7 +25,7 @@ import numpy as np
...
@@ -25,7 +25,7 @@ import numpy as np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.configuration_ctrl
import
CTRLConfig
from
.configuration_ctrl
import
CTRLConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
,
TFSharedEmbeddings
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
...
@@ -33,12 +33,19 @@ logger = logging.getLogger(__name__)
...
@@ -33,12 +33,19 @@ logger = logging.getLogger(__name__)
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"ctrl"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"
}
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"ctrl"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"
}
def
load_ctrl_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
angle_defn
(
pos
,
i
,
d_model_size
):
def
angle_defn
(
pos
,
i
,
d_model_size
):
angle_rates
=
1
/
np
.
power
(
10000
,
(
2
*
(
i
//
2
))
/
np
.
float32
(
d_model_size
))
angle_rates
=
1
/
np
.
power
(
10000
,
(
2
*
(
i
//
2
))
/
np
.
float32
(
d_model_size
))
return
pos
*
angle_rates
return
pos
*
angle_rates
def
positional_encoding
(
position
,
d_model_size
,
dtype
):
def
positional_encoding
(
position
,
d_model_size
):
# create the sinusoidal pattern for the positional encoding
# create the sinusoidal pattern for the positional encoding
angle_rads
=
angle_defn
(
np
.
arange
(
position
)[:,
np
.
newaxis
],
angle_rads
=
angle_defn
(
np
.
arange
(
position
)[:,
np
.
newaxis
],
np
.
arange
(
d_model_size
)[
np
.
newaxis
,
:],
np
.
arange
(
d_model_size
)[
np
.
newaxis
,
:],
...
@@ -47,14 +54,15 @@ def positional_encoding(position, d_model_size, dtype):
...
@@ -47,14 +54,15 @@ def positional_encoding(position, d_model_size, dtype):
sines
=
np
.
sin
(
angle_rads
[:,
0
::
2
])
sines
=
np
.
sin
(
angle_rads
[:,
0
::
2
])
cosines
=
np
.
cos
(
angle_rads
[:,
1
::
2
])
cosines
=
np
.
cos
(
angle_rads
[:,
1
::
2
])
pos_encoding
=
tf
.
cast
(
np
.
concatenate
([
sines
,
cosines
],
axis
=-
1
)[
np
.
newaxis
,
...],
dtype
=
tf
.
float32
)
# pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
pos_encoding
=
tf
.
cast
(
np
.
concatenate
([
sines
,
cosines
],
axis
=-
1
),
dtype
=
tf
.
float32
)
return
pos_encoding
return
pos_encoding
def
scaled_dot_product_attention
(
q
,
k
,
v
,
mask
,
attention_mask
=
None
,
head_mask
=
None
):
def
scaled_dot_product_attention
(
q
,
k
,
v
,
mask
,
attention_mask
=
None
,
head_mask
=
None
):
# calculate attention
# calculate attention
matmul_qk
=
tf
.
matmul
(
q
,
k
,
transpose_b
=
True
)
matmul_qk
=
tf
.
matmul
(
q
,
k
,
transpose_b
=
True
)
dk
=
tf
.
cast
(
tf
.
shape
(
k
)[
-
1
],
tf
.
float32
)
dk
=
tf
.
cast
(
shape
_list
(
k
)[
-
1
],
tf
.
float32
)
scaled_attention_logits
=
matmul_qk
/
tf
.
math
.
sqrt
(
dk
)
scaled_attention_logits
=
matmul_qk
/
tf
.
math
.
sqrt
(
dk
)
if
mask
is
not
None
:
if
mask
is
not
None
:
...
@@ -94,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -94,7 +102,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
x
=
tf
.
reshape
(
x
,
(
batch_size
,
-
1
,
self
.
num_heads
,
self
.
depth
))
x
=
tf
.
reshape
(
x
,
(
batch_size
,
-
1
,
self
.
num_heads
,
self
.
depth
))
return
tf
.
transpose
(
x
,
perm
=
[
0
,
2
,
1
,
3
])
return
tf
.
transpose
(
x
,
perm
=
[
0
,
2
,
1
,
3
])
def
call
(
self
,
inputs
,
training
=
False
)
def
call
(
self
,
inputs
,
training
=
False
)
:
v
,
k
,
q
,
mask
,
layer_past
,
attention_mask
,
head_mask
=
inputs
v
,
k
,
q
,
mask
,
layer_past
,
attention_mask
,
head_mask
=
inputs
batch_size
=
q
.
shape
[
0
]
batch_size
=
q
.
shape
[
0
]
...
@@ -124,31 +132,34 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
...
@@ -124,31 +132,34 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
def
point_wise_feed_forward_network
(
d_model_size
,
dff
):
def
point_wise_feed_forward_network
(
d_model_size
,
dff
,
name
=
""
):
return
tf
.
keras
.
Sequential
([
tf
.
keras
.
layers
.
Dense
(
dff
,
activation
=
'relu'
),
return
tf
.
keras
.
Sequential
([
tf
.
keras
.
layers
.
Dense
(
d_model_size
)])
tf
.
keras
.
layers
.
Dense
(
dff
,
activation
=
'relu'
,
name
=
"0"
),
tf
.
keras
.
layers
.
Dense
(
d_model_size
,
name
=
"2"
)
],
name
=
"ffn"
)
class
TFEncoderLayer
(
tf
.
keras
.
layers
.
Layer
):
class
TFEncoderLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
d_model_size
,
num_heads
,
dff
,
rate
=
0.1
,
output_attentions
=
False
,
**
kwargs
):
def
__init__
(
self
,
d_model_size
,
num_heads
,
dff
,
rate
=
0.1
,
layer_norm_epsilon
=
1e-6
,
output_attentions
=
False
,
**
kwargs
):
super
(
TFEncoderLayer
,
self
).
__init__
(
**
kwargs
)
super
(
TFEncoderLayer
,
self
).
__init__
(
**
kwargs
)
self
.
multi_head_attention
=
MultiHeadAttention
(
d_model_size
,
num_heads
,
output_attentions
)
self
.
multi_head_attention
=
TFMultiHeadAttention
(
d_model_size
,
self
.
ffn
=
point_wise_feed_forward_network
(
d_model_size
,
dff
)
num_heads
,
output_attentions
,
name
=
"multi_head_attention"
)
self
.
ffn
=
point_wise_feed_forward_network
(
d_model_size
,
dff
,
name
=
"ffn"
)
self
.
layernorm1
=
t
orch
.
nn
.
LayerNorm
(
d_model_size
,
eps
=
1e-6
)
self
.
layernorm1
=
t
f
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
layer_norm_epsilon
,
name
=
"layernorm1"
)
self
.
layernorm2
=
t
orch
.
nn
.
LayerNorm
(
d_model_size
,
eps
=
1e-6
)
self
.
layernorm2
=
t
f
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
layer_norm_epsilon
,
name
=
"layernorm2"
)
self
.
dropout1
=
t
orch
.
nn
.
Dropout
(
rate
)
self
.
dropout1
=
t
f
.
keras
.
layers
.
Dropout
(
rate
)
self
.
dropout2
=
t
orch
.
nn
.
Dropout
(
rate
)
self
.
dropout2
=
t
f
.
keras
.
layers
.
Dropout
(
rate
)
def
call
(
self
,
inputs
,
training
=
False
):
def
call
(
self
,
inputs
,
training
=
False
):
x
,
mask
,
layer_past
,
attention_mask
,
head_mask
=
inputs
x
,
mask
,
layer_past
,
attention_mask
,
head_mask
=
inputs
normed
=
self
.
layernorm1
(
x
)
normed
=
self
.
layernorm1
(
x
)
attn_outputs
=
self
.
multi_head_attention
(
normed
,
normed
,
normed
,
mask
,
attn_outputs
=
self
.
multi_head_attention
([
normed
,
normed
,
normed
,
mask
,
layer_past
,
layer_past
=
layer_past
,
attention_mask
,
head_mask
],
training
=
training
)
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
attn_output
=
attn_outputs
[
0
]
attn_output
=
attn_outputs
[
0
]
attn_output
=
self
.
dropout1
(
attn_output
,
training
=
training
)
attn_output
=
self
.
dropout1
(
attn_output
,
training
=
training
)
out1
=
x
+
attn_output
out1
=
x
+
attn_output
...
@@ -162,6 +173,152 @@ class TFEncoderLayer(tf.keras.layers.Layer):
...
@@ -162,6 +173,152 @@ class TFEncoderLayer(tf.keras.layers.Layer):
return
outputs
return
outputs
class
TFCTRLMainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
(
TFCTRLMainLayer
,
self
).
__init__
(
**
kwargs
)
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
d_model_size
=
config
.
n_embd
self
.
num_layers
=
config
.
n_layer
self
.
pos_encoding
=
positional_encoding
(
config
.
n_positions
,
self
.
d_model_size
)
self
.
output_attentions
=
config
.
output_attentions
self
.
w
=
TFSharedEmbeddings
(
config
.
vocab_size
,
config
.
n_embd
,
initializer_range
=
config
.
initializer_range
,
name
=
"w"
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
embd_pdrop
)
self
.
h
=
[
TFEncoderLayer
(
config
.
n_embd
,
config
.
n_head
,
config
.
dff
,
config
.
resid_pdrop
,
config
.
layer_norm_epsilon
,
config
.
output_attentions
,
name
=
'h_._{}'
.
format
(
i
))
for
i
in
range
(
config
.
n_layer
)]
self
.
layernorm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
config
.
layer_norm_epsilon
,
name
=
"layernorm"
)
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
raise
NotImplementedError
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise
NotImplementedError
def
call
(
self
,
inputs
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
training
=
False
):
if
isinstance
(
inputs
,
(
tuple
,
list
)):
input_ids
=
inputs
[
0
]
past
=
inputs
[
1
]
if
len
(
inputs
)
>
1
else
past
attention_mask
=
inputs
[
2
]
if
len
(
inputs
)
>
2
else
attention_mask
token_type_ids
=
inputs
[
3
]
if
len
(
inputs
)
>
3
else
token_type_ids
position_ids
=
inputs
[
4
]
if
len
(
inputs
)
>
4
else
position_ids
head_mask
=
inputs
[
5
]
if
len
(
inputs
)
>
5
else
head_mask
assert
len
(
inputs
)
<=
6
,
"Too many inputs."
elif
isinstance
(
inputs
,
dict
):
input_ids
=
inputs
.
get
(
'input_ids'
)
past
=
inputs
.
get
(
'past'
,
past
)
attention_mask
=
inputs
.
get
(
'attention_mask'
,
attention_mask
)
token_type_ids
=
inputs
.
get
(
'token_type_ids'
,
token_type_ids
)
position_ids
=
inputs
.
get
(
'position_ids'
,
position_ids
)
head_mask
=
inputs
.
get
(
'head_mask'
,
head_mask
)
assert
len
(
inputs
)
<=
6
,
"Too many inputs."
else
:
input_ids
=
inputs
input_shape
=
shape_list
(
input_ids
)
input_ids
=
tf
.
reshape
(
input_ids
,
[
-
1
,
input_shape
[
-
1
]])
if
past
is
None
:
past_length
=
0
past
=
[
None
]
*
len
(
self
.
h
)
else
:
past_length
=
shape_list
(
past
[
0
][
0
])[
-
2
]
if
position_ids
is
None
:
position_ids
=
tf
.
range
(
past_length
,
shape_list
(
input_ids
)[
-
1
]
+
past_length
,
dtype
=
tf
.
int32
)[
tf
.
newaxis
,
:]
# Attention mask.
if
attention_mask
is
not
None
:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask
=
attention_mask
[:,
tf
.
newaxis
,
tf
.
newaxis
,
:]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask
=
tf
.
cast
(
attention_mask
,
tf
.
float32
)
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
else
:
attention_mask
=
None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if
head_mask
is
not
None
:
raise
NotImplementedError
else
:
head_mask
=
[
None
]
*
self
.
num_layers
if
token_type_ids
is
not
None
:
token_type_ids
=
tf
.
reshape
(
token_type_ids
,
[
-
1
,
shape_list
(
token_type_ids
)[
-
1
]])
token_type_embeds
=
self
.
w
(
token_type_ids
,
mode
=
'embedding'
)
token_type_embeds
*=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
d_model_size
,
tf
.
float32
))
else
:
token_type_embeds
=
0
position_ids
=
tf
.
reshape
(
position_ids
,
[
-
1
,
shape_list
(
position_ids
)[
-
1
]])
inputs_embeds
=
self
.
w
(
input_ids
)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_shape
[
-
1
]
mask
=
1
-
tf
.
linalg
.
band_part
(
tf
.
ones
((
seq_len
,
seq_len
)),
-
1
,
0
)
inputs_embeds
*=
tf
.
math
.
sqrt
(
tf
.
cast
(
self
.
d_model_size
,
tf
.
float32
))
pos_embeds
=
tf
.
gather
(
self
.
pos_encoding
,
position_ids
)
hidden_states
=
inputs_embeds
+
pos_embeds
+
token_type_embeds
hidden_states
=
self
.
dropout
(
hidden_states
,
training
=
training
)
output_shape
=
input_shape
+
[
shape_list
(
hidden_states
)[
-
1
]]
presents
=
()
all_hidden_states
=
()
all_attentions
=
[]
for
i
,
(
h
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
past
)):
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
tf
.
reshape
(
hidden_states
,
output_shape
),)
outputs
=
h
([
hidden_states
,
mask
,
layer_past
,
attention_mask
,
head_mask
[
i
]],
training
=
training
)
hidden_states
,
present
=
outputs
[:
2
]
presents
=
presents
+
(
present
,)
if
self
.
output_attentions
:
all_attentions
.
append
(
outputs
[
2
])
hidden_states
=
self
.
layernorm
(
hidden_states
)
hidden_states
=
tf
.
reshape
(
hidden_states
,
output_shape
)
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
outputs
=
(
hidden_states
,
presents
)
if
self
.
output_hidden_states
:
outputs
=
outputs
+
(
all_hidden_states
,)
if
self
.
output_attentions
:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape
=
input_shape
[:
-
1
]
+
[
-
1
]
+
shape_list
(
all_attentions
[
0
])[
-
2
:]
all_attentions
=
tuple
(
tf
.
reshape
(
t
,
attention_output_shape
)
for
t
in
all_attentions
)
outputs
=
outputs
+
(
all_attentions
,)
return
outputs
class
TFCTRLPreTrainedModel
(
TFPreTrainedModel
):
class
TFCTRLPreTrainedModel
(
TFPreTrainedModel
):
""" An abstract class to handle weights initialization and
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
a simple interface for dowloading and loading pretrained models.
...
@@ -169,20 +326,7 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
...
@@ -169,20 +326,7 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
config_class
=
CTRLConfig
config_class
=
CTRLConfig
pretrained_model_archive_map
=
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
load_pt_weights
=
load_bert_pt_weights_in_tf2
load_pt_weights
=
load_ctrl_pt_weights_in_tf2
def
_init_weights
(
self
,
module
):
""" Initialize the weights.
"""
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Embedding
,
Conv1D
)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
)
if
isinstance
(
module
,
(
nn
.
Linear
,
Conv1D
))
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
elif
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
CTRL_START_DOCSTRING
=
r
""" CTRL model was proposed in
CTRL_START_DOCSTRING
=
r
""" CTRL model was proposed in
...
@@ -240,172 +384,68 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
...
@@ -240,172 +384,68 @@ CTRL_INPUTS_DOCSTRING = r""" Inputs:
class
TFCTRLModel
(
TFCTRLPreTrainedModel
):
class
TFCTRLModel
(
TFCTRLPreTrainedModel
):
r
"""
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``t
orch.Float
Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
**last_hidden_state**: ``t
f.
Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
Sequence of hidden-states at the last layer of the model.
**past**:
**past**:
list of ``t
orch.Float
Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of ``t
f.
Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``t
orch.Float
Tensor`` (one for the output of each layer + the output of the embeddings)
list of ``t
f.
Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``t
orch.Float
Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of ``t
f.
Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
Examples::
import tensorflow as tf
from transformers import CTRLTokenizer, TFCTRLModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLModel.from_pretrained('ctrl')
model =
TF
CTRLModel.from_pretrained('ctrl')
input_ids = t
orch.tensor
(tokenizer.encode("
Links
Hello, my dog is cute"))
.unsqueeze(0)
# Batch size 1
input_ids = t
f.constant
(tokenizer.encode("Hello, my dog is cute"))
[None, :]
# Batch size 1
outputs = model(input_ids)
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
"""
def
__init__
(
self
,
config
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFCTRLModel
,
self
).
__init__
(
**
kwargs
)
super
(
TFCTRLModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
transformer
=
TFCTRLMainLayer
(
config
,
name
=
'transformer'
)
self
.
d_model_size
=
config
.
n_embd
self
.
num_layers
=
config
.
n_layer
self
.
pos_encoding
=
positional_encoding
(
config
.
n_positions
,
self
.
d_model_size
,
torch
.
float
)
self
.
output_attentions
=
config
.
output_attentions
self
.
w
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
n_embd
)
def
call
(
self
,
inputs
,
**
kwargs
):
self
.
dropout
=
nn
.
Dropout
(
config
.
embd_pdrop
)
outputs
=
self
.
transformer
(
inputs
,
**
kwargs
)
self
.
h
=
nn
.
ModuleList
([
EncoderLayer
(
config
.
n_embd
,
return
outputs
config
.
n_head
,
config
.
dff
,
config
.
resid_pdrop
,
config
.
output_attentions
)
for
_
in
range
(
config
.
n_layer
)])
self
.
layernorm
=
nn
.
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
init_weights
()
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
self
.
w
=
self
.
_get_resized_embeddings
(
self
.
w
,
new_num_tokens
)
return
self
.
w
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
):
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
position_ids
is
not
None
:
position_ids
=
position_ids
.
view
(
-
1
,
input_shape
[
-
1
])
if
past
is
None
:
past_length
=
0
past
=
[
None
]
*
len
(
self
.
h
)
else
:
past_length
=
past
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
past_length
,
input_ids
.
size
(
-
1
)
+
past_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
# Attention mask.
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
.
view
(
-
1
,
input_shape
[
-
1
])
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask
=
attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
x
=
self
.
w
(
input_ids
)
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
seq_len
=
input_ids
.
shape
[
1
]
mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
),
1
).
to
(
x
.
device
)
x
*=
np
.
sqrt
(
self
.
d_model_size
)
pos_x
=
self
.
pos_encoding
[
position_ids
,
:].
to
(
x
.
device
)
x
+=
pos_x
x
=
self
.
dropout
(
x
)
output_shape
=
input_shape
+
(
x
.
size
(
-
1
),)
class
TFCTRLLMHead
(
tf
.
keras
.
layers
.
Layer
):
presents
=
()
def
__init__
(
self
,
config
,
input_embeddings
,
**
kwargs
):
all_hidden_states
=
()
super
(
TFCTRLLMHead
,
self
).
__init__
(
**
kwargs
)
all_attentions
=
[]
self
.
vocab_size
=
config
.
vocab_size
for
i
,
(
h
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
past
)):
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
x
.
view
(
*
output_shape
),)
outputs
=
h
(
x
,
mask
,
layer_past
=
layer_past
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
[
i
])
x
,
present
=
outputs
[:
2
]
presents
=
presents
+
(
present
,)
if
self
.
output_attentions
:
# The output weights are the same as the input embeddings, but there is
all_attentions
.
append
(
outputs
[
2
])
# an output-only bias for each token.
self
.
input_embeddings
=
input_embeddings
x
=
self
.
layernorm
(
x
)
def
build
(
self
,
input_shape
):
x
=
x
.
view
(
*
output_shape
)
self
.
bias
=
self
.
add_weight
(
shape
=
(
self
.
vocab_size
,),
if
self
.
output_hidden_states
:
initializer
=
'zeros'
,
all_hidden_states
=
all_hidden_states
+
(
x
,)
trainable
=
True
,
name
=
'bias'
)
super
(
TFCTRLLMHead
,
self
).
build
(
input_shape
)
outputs
=
(
x
,
presents
)
def
call
(
self
,
hidden_states
):
if
self
.
output_hidden_states
:
hidden_states
=
self
.
input_embeddings
(
hidden_states
,
mode
=
"linear"
)
outputs
=
outputs
+
(
all_hidden_states
,)
hidden_states
=
hidden_states
+
self
.
bias
if
self
.
output_attentions
:
return
hidden_states
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape
=
input_shape
[:
-
1
]
+
(
-
1
,)
+
all_attentions
[
0
].
shape
[
-
2
:]
all_attentions
=
tuple
(
t
.
view
(
*
attention_output_shape
)
for
t
in
all_attentions
)
outputs
=
outputs
+
(
all_attentions
,)
return
outputs
@
add_start_docstrings
(
"""The CTRL Model transformer with a language modeling head on top
@
add_start_docstrings
(
"""The CTRL Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """
,
CTRL_START_DOCSTRING
,
CTRL_INPUTS_DOCSTRING
)
(linear layer with weights tied to the input embeddings). """
,
CTRL_START_DOCSTRING
,
CTRL_INPUTS_DOCSTRING
)
class
CTRLLMHeadModel
(
CTRLPreTrainedModel
):
class
TF
CTRLLMHeadModel
(
TF
CTRLPreTrainedModel
):
r
"""
r
"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
**past**:
...
@@ -423,53 +463,28 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
...
@@ -423,53 +463,28 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
Examples::
Examples::
import torch
import torch
from transformers import CTRLTokenizer, CTRLLMHeadModel
from transformers import CTRLTokenizer,
TF
CTRLLMHeadModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLLMHeadModel.from_pretrained('ctrl')
model =
TF
CTRLLMHeadModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
loss, logits = outputs[:2]
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
CTRLLMHeadModel
,
self
).
__init__
(
config
)
super
(
TFCTRLLMHeadModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
transformer
=
CTRLModel
(
config
)
self
.
transformer
=
TFCTRLMainLayer
(
config
,
name
=
'transformer'
)
self
.
lm_head
=
nn
.
Linear
(
config
.
n_embd
,
config
.
vocab_size
,
bias
=
True
)
self
.
init_weights
()
self
.
tie_weights
()
def
tie_weights
(
self
):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self
.
_tie_or_clone_weights
(
self
.
lm_head
,
self
.
transformer
.
w
)
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
self
.
lm_head
=
TFCTRLLMHead
(
config
,
self
.
transformer
.
w
,
name
=
"lm_head"
)
labels
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
past
=
past
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
position_ids
=
position_ids
,
head_mask
=
head_mask
)
def
call
(
self
,
inputs
,
**
kwargs
):
transformer_outputs
=
self
.
transformer
(
inputs
,
**
kwargs
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
outputs
=
(
lm_logits
,)
+
transformer_outputs
[
1
:]
outputs
=
(
lm_logits
,)
+
transformer_outputs
[
1
:]
if
labels
is
not
None
:
return
outputs
# lm_logits, presents, (all hidden_states), (attentions)
# Shift so that tokens < n predict n
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_labels
=
labels
[...,
1
:].
contiguous
()
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), lm_logits, presents, (all hidden_states), (attentions)
transformers/tests/modeling_tf_ctrl_test.py
0 → 100644
View file @
c56d921d
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
shutil
import
pytest
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
transformers
import
CTRLConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
transformers.modeling_tf_ctrl
import
(
TFCTRLModel
,
TFCTRLLMHeadModel
,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
    """Test suite for the TF 2.0 CTRL models.

    Combines the shared checks inherited from ``TFCommonModelTester`` with
    CTRL-specific output-shape checks implemented by the nested
    ``TFCTRLModelTester`` helper.
    """

    # Model classes under test; empty when TensorFlow is unavailable, since
    # the TF model classes are only imported when it is installed.
    all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()

    class TFCTRLModelTester(object):
        """Builds a small CTRL config plus inputs and checks model output shapes."""

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_token_type_ids=True,
                     use_input_mask=True,
                     use_labels=True,
                     use_mc_token_ids=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     scope=None,
                     ):
            """Store the owning test case (``parent``) and every hyper-parameter as-is."""
            # The test case whose assertion methods are used by the checks.
            self.parent = parent

            # Input dimensions.
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training

            # Switches selecting which optional inputs / labels get generated.
            self.use_token_type_ids = use_token_type_ids
            self.use_input_mask = use_input_mask
            self.use_labels = use_labels
            self.use_mc_token_ids = use_mc_token_ids

            # Model dimensions.
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings

            # Label / auxiliary-input settings.
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope

        def prepare_config_and_inputs(self):
            """Create a ``CTRLConfig`` and the tensors used by the checks.

            Returns a 9-tuple ``(config, input_ids, input_mask, head_mask,
            token_type_ids, mc_token_ids, sequence_labels, token_labels,
            choice_labels)``. Optional entries are ``None`` when the
            corresponding ``use_*`` flag is off.
            """
            batch_by_seq = [self.batch_size, self.seq_length]

            # NOTE: the ids_tensor calls are kept in their original order in
            # case the helper draws from a shared random generator.
            input_ids = ids_tensor(batch_by_seq, self.vocab_size)

            input_mask = (
                ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
                if self.use_input_mask else None
            )
            token_type_ids = (
                ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
                if self.use_token_type_ids else None
            )
            mc_token_ids = (
                ids_tensor([self.batch_size, self.num_choices], self.seq_length)
                if self.use_mc_token_ids else None
            )

            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)
            else:
                sequence_labels = token_labels = choice_labels = None

            # Only the sizes below are forwarded to the config; the remaining
            # tester hyper-parameters (intermediate_size, hidden_act, dropout
            # probabilities, type_vocab_size, initializer_range) are not passed.
            config = CTRLConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                n_embd=self.hidden_size,
                n_layer=self.num_hidden_layers,
                n_head=self.num_attention_heads,
                n_positions=self.max_position_embeddings,
                n_ctx=self.max_position_embeddings,
            )

            head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

            return (config, input_ids, input_mask, head_mask, token_type_ids,
                    mc_token_ids, sequence_labels, token_labels, choice_labels)

        def create_and_check_ctrl_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
            """Call ``TFCTRLModel`` with dict, list and bare-tensor inputs.

            Asserts that the first output of the last call has shape
            ``[batch_size, seq_length, hidden_size]``.
            """
            model = TFCTRLModel(config=config)

            # 1) Keyword-style inputs packed in a dict.
            as_dict = {'input_ids': input_ids,
                       'attention_mask': input_mask,
                       'token_type_ids': token_type_ids}
            sequence_output = model(as_dict)[0]

            # 2) Positional-style inputs packed in a list; the None entry is
            #    the input for 'past'.
            as_list = [input_ids, None, input_mask]
            sequence_output = model(as_list)[0]

            # 3) The input ids tensor on its own.
            sequence_output = model(input_ids)[0]

            observed_shape = list(sequence_output.numpy().shape)
            expected_shape = [self.batch_size, self.seq_length, self.hidden_size]
            self.parent.assertListEqual(observed_shape, expected_shape)

        def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
            """Call ``TFCTRLLMHeadModel`` and check the prediction-score shape.

            Asserts that the first output has shape
            ``[batch_size, seq_length, vocab_size]``.
            """
            model = TFCTRLLMHeadModel(config=config)
            model_inputs = {'input_ids': input_ids,
                            'attention_mask': input_mask,
                            'token_type_ids': token_type_ids}
            prediction_scores = model(model_inputs)[0]

            observed_shape = list(prediction_scores.numpy().shape)
            expected_shape = [self.batch_size, self.seq_length, self.vocab_size]
            self.parent.assertListEqual(observed_shape, expected_shape)

        def prepare_config_and_inputs_for_common(self):
            """Return ``(config, inputs_dict)`` for the shared common tests."""
            (config, input_ids, input_mask, _head_mask, token_type_ids,
             _mc_token_ids, _sequence_labels, _token_labels,
             _choice_labels) = self.prepare_config_and_inputs()

            inputs_dict = {'input_ids': input_ids,
                           'token_type_ids': token_type_ids,
                           'attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        """Create the model tester and the config tester for each test."""
        self.model_tester = TFCTRLModelTest.TFCTRLModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)

    def test_config(self):
        """Run the common configuration tests against ``CTRLConfig``."""
        self.config_tester.run_common_tests()

    def test_ctrl_model(self):
        """Check the output shape of the base CTRL model."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_ctrl_model(*prepared)

    def test_ctrl_lm_head(self):
        """Check the output shape of the CTRL language-modeling head model."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_ctrl_lm_head(*prepared)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        """Load the first archived pretrained checkpoint and check it is not None."""
        cache_dir = "/tmp/transformers_test/"
        first_model_names = list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]
        for model_name in first_model_names:
            model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
            # Remove the cache directory right after loading.
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)
# Run this module's tests through unittest when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
transformers/tests/modeling_tf_gpt2_test.py
View file @
c56d921d
...
@@ -222,7 +222,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -222,7 +222,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@
pytest
.
mark
.
slow
@
pytest
.
mark
.
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
TF_
gpt
2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
TF_
GPT
2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFGPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
TFGPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment