chenpangpang / transformers · Commits

Commit 70d97ddd, authored Nov 11, 2019 by Julien Chaumond

[TF models] Common attributes as per #1721

parent 872403be

Showing 11 changed files with 84 additions and 0 deletions (+84, -0):
transformers/modeling_tf_bert.py                +9  -0
transformers/modeling_tf_ctrl.py                +6  -0
transformers/modeling_tf_distilbert.py          +6  -0
transformers/modeling_tf_gpt2.py                +9  -0
transformers/modeling_tf_openai.py              +9  -0
transformers/modeling_tf_roberta.py             +6  -0
transformers/modeling_tf_transfo_xl.py          +3  -0
transformers/modeling_tf_utils.py               +15 -0
transformers/modeling_tf_xlm.py                 +5  -0
transformers/modeling_tf_xlnet.py               +6  -0
transformers/tests/modeling_tf_common_test.py   +10 -0
transformers/modeling_tf_bert.py (+9, -0)

```diff
@@ -460,6 +460,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         self.encoder = TFBertEncoder(config, name='encoder')
         self.pooler = TFBertPooler(config, name='pooler')

+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -702,6 +705,9 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.nsp = TFBertNSPHead(config, name='nsp___cls')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
@@ -747,6 +753,9 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')

+    def get_output_embeddings(self):
+        return self.bert.embeddings
+
     def call(self, inputs, **kwargs):
         outputs = self.bert(inputs, **kwargs)
```
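The pattern above gives every task model a uniform way to reach the (tied) embedding weights. A minimal sketch of how the new accessors behave on this branch, assuming transformers is installed with TF 2.0 support; `BertConfig()` is just a default config used to build an untrained model:

```python
from transformers import BertConfig, TFBertForMaskedLM

# Untrained model from a default config; no pretrained weights needed.
model = TFBertForMaskedLM(BertConfig())

# Inherited from TFPreTrainedModel: delegates to the base model (model.bert),
# whose TFBertMainLayer.get_input_embeddings returns its embeddings layer.
input_emb = model.get_input_embeddings()

# Overridden above: the MLM head reuses the same embeddings object.
output_emb = model.get_output_embeddings()

assert input_emb is output_emb  # BERT ties input and output embeddings
```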
transformers/modeling_tf_ctrl.py (+6, -0)

```diff
@@ -192,6 +192,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
                            name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")

+    def get_input_embeddings(self):
+        return self.w
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -480,6 +483,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
         self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

+    def get_output_embeddings(self):
+        return self.lm_head.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
```
transformers/modeling_tf_distilbert.py (+6, -0)

```diff
@@ -398,6 +398,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
         self.embeddings = TFEmbeddings(config, name="embeddings")   # Embeddings
         self.transformer = TFTransformer(config, name="transformer") # Encoder

+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -613,6 +616,9 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
         self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
         self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

+    def get_output_embeddings(self):
+        return self.vocab_projector.input_embeddings
+
     def call(self, inputs, **kwargs):
         distilbert_output = self.distilbert(inputs, **kwargs)
```
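Several heads touched by this commit (DistilBERT's vocab_projector, CTRL's lm_head, XLM's pred_layer, XLNet's lm_loss) are built around a shared embedding layer that they keep as `.input_embeddings`, which is what their models' `get_output_embeddings` surfaces. A sketch, assuming this branch is installed and that the model stores its main layer as `.distilbert` (its base_model_prefix):

```python
from transformers import DistilBertConfig, TFDistilBertForMaskedLM

model = TFDistilBertForMaskedLM(DistilBertConfig())

# The projector was constructed with self.distilbert.embeddings, which it
# stores as .input_embeddings; the accessor surfaces that same tied layer.
assert model.get_output_embeddings() is model.distilbert.embeddings
```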
transformers/modeling_tf_gpt2.py (+9, -0)

```diff
@@ -219,6 +219,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
                            name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

+    def get_input_embeddings(self):
+        return self.wte
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -490,6 +493,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')

+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -560,6 +566,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.transformer = TFGPT2MainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')

+    def get_output_embeddings(self):
+        return self.transformer.wte
+
     def call(self, inputs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
```
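For GPT-2 the output embeddings are literally the input embedding matrix (`transformer.wte`), so both accessors should hand back one and the same layer. A small sketch under the same assumptions as above:

```python
from transformers import GPT2Config, TFGPT2LMHeadModel

model = TFGPT2LMHeadModel(GPT2Config())

# get_input_embeddings() delegates to TFGPT2MainLayer and returns its wte;
# get_output_embeddings() returns self.transformer.wte directly.
assert model.get_output_embeddings() is model.get_input_embeddings()
```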
transformers/modeling_tf_openai.py (+9, -0)

```diff
@@ -217,6 +217,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
                            scale=True,
                            name='h_._{}'.format(i)) for i in range(config.n_layer)]

+    def get_input_embeddings(self):
+        return self.tokens_embed
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -462,6 +465,9 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
         super(TFOpenAIGPTLMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')

+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_states = transformer_outputs[0]
@@ -524,6 +530,9 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
         self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
         self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')

+    def get_output_embeddings(self):
+        return self.transformer.tokens_embed
+
     def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, training=False):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
```
transformers/modeling_tf_roberta.py (+6, -0)

```diff
@@ -65,6 +65,9 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super(TFRobertaMainLayer, self).__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name='embeddings')

+    def get_input_embeddings(self):
+        return self.embeddings
+
 class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
@@ -280,6 +283,9 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
         self.roberta = TFRobertaMainLayer(config, name="roberta")
         self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head")

+    def get_output_embeddings(self):
+        return self.lm_head.decoder
+
     def call(self, inputs, **kwargs):
         outputs = self.roberta(inputs, **kwargs)
```
transformers/modeling_tf_transfo_xl.py (+3, -0)

```diff
@@ -413,6 +413,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                            name='r_r_bias')
         super(TFTransfoXLMainLayer, self).build(input_shape)

+    def get_input_embeddings(self):
+        return self.word_emb
+
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
```
transformers/modeling_tf_utils.py (+15, -0)

```diff
@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
         # Save config in model
         self.config = config

+    def get_input_embeddings(self):
+        """ Get model's input embeddings
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def get_output_embeddings(self):
+        """ Get model's output embeddings
+            Return None if the model doesn't have output embeddings
+        """
+        return None  # Overwrite for models with output embeddings
+
     def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
         """ Build a resized Embedding Variable from a provided token Embedding Module.
             Increasing the size will add newly initialized vectors at the end
```
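The base-class implementation leans on `base_model_prefix`: a task model such as `TFBertForMaskedLM` stores its main layer under that attribute name (`'bert'`), so the `getattr` lookup finds it and forwards the call, while main layers override the method themselves. A standalone sketch of that delegation, with deliberately hypothetical class names:

```python
class FakeMainLayer:
    """Stands in for a main layer (e.g. TFBertMainLayer) that owns the embeddings."""
    def get_input_embeddings(self):
        return "embeddings-layer"

class FakePreTrainedModel:
    base_model_prefix = ""

    def get_input_embeddings(self):
        # Look up the wrapped main layer by attribute name; fall back to self.
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError  # main layers must define this themselves

class FakeBertForTask(FakePreTrainedModel):
    base_model_prefix = "bert"

    def __init__(self):
        self.bert = FakeMainLayer()

print(FakeBertForTask().get_input_embeddings())  # -> "embeddings-layer"
```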
transformers/modeling_tf_xlm.py (+5, -0)

```diff
@@ -277,6 +277,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
             self.prune_heads({int(layer): list(map(int, heads))})

+    def get_input_embeddings(self):
+        return self.embeddings
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
@@ -641,6 +644,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
         self.transformer = TFXLMMainLayer(config, name='transformer')
         self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')

+    def get_output_embeddings(self):
+        return self.pred_layer.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
```
transformers/modeling_tf_xlnet.py (+6, -0)

```diff
@@ -371,6 +371,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
         self.dropout = tf.keras.layers.Dropout(config.dropout)

+    def get_input_embeddings(self):
+        return self.word_embedding
+
     def build(self, input_shape):
         initializer = get_initializer(self.initializer_range)
         self.mask_emb = self.add_weight(shape=(1, 1, self.d_model),
@@ -854,6 +857,9 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
         self.transformer = TFXLNetMainLayer(config, name='transformer')
         self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name='lm_loss')

+    def get_output_embeddings(self):
+        return self.lm_loss.input_embeddings
+
     def call(self, inputs, **kwargs):
         transformer_outputs = self.transformer(inputs, **kwargs)
         hidden_state = transformer_outputs[0]
```
transformers/tests/modeling_tf_common_test.py (+10, -0)

```diff
@@ -360,6 +360,16 @@ class TFCommonTestCases:
         #         self.assertTrue(models_equal)

+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_embeddings()
+            assert x is None or isinstance(x, tf.keras.layers.Layer)
+
     def test_tie_model_weights(self):
         pass
         # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
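Outside the ModelTester harness, the same invariant can be checked by hand. A rough sketch using default configs instead of `prepare_config_and_inputs_for_common` (the choice of model classes here is illustrative):

```python
import tensorflow as tf
from transformers import BertConfig, TFBertForMaskedLM, TFBertForPreTraining

for model_class in (TFBertForMaskedLM, TFBertForPreTraining):
    model = model_class(BertConfig())
    # Every model must expose its input embeddings as a Keras layer...
    assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
    # ...and output embeddings are either absent (None) or a Keras layer too.
    out = model.get_output_embeddings()
    assert out is None or isinstance(out, tf.keras.layers.Layer)
```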