Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1724cee8
Commit
1724cee8
authored
Nov 04, 2019
by
thomwolf
Browse files
switch from properties to methods
parent
9b45d0f8
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
70 additions
and
75 deletions
+70
-75
templates/adding_a_new_model/modeling_xxx.py
templates/adding_a_new_model/modeling_xxx.py
+3
-5
transformers/modeling_bert.py
transformers/modeling_bert.py
+5
-9
transformers/modeling_ctrl.py
transformers/modeling_ctrl.py
+3
-6
transformers/modeling_distilbert.py
transformers/modeling_distilbert.py
+3
-6
transformers/modeling_gpt2.py
transformers/modeling_gpt2.py
+4
-8
transformers/modeling_openai.py
transformers/modeling_openai.py
+4
-8
transformers/modeling_roberta.py
transformers/modeling_roberta.py
+4
-4
transformers/modeling_transfo_xl.py
transformers/modeling_transfo_xl.py
+2
-4
transformers/modeling_utils.py
transformers/modeling_utils.py
+27
-10
transformers/modeling_xlm.py
transformers/modeling_xlm.py
+3
-6
transformers/modeling_xlnet.py
transformers/modeling_xlnet.py
+3
-6
transformers/tests/modeling_common_test.py
transformers/tests/modeling_common_test.py
+9
-3
No files found.
templates/adding_a_new_model/modeling_xxx.py
View file @
1724cee8
...
@@ -281,11 +281,10 @@ class XxxModel(XxxPreTrainedModel):
...
@@ -281,11 +281,10 @@ class XxxModel(XxxPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
@
property
def
input_embeddings
(
self
):
def
get_
input_embeddings
(
self
):
return
self
.
embeddings
.
word_embeddings
return
self
.
embeddings
.
word_embeddings
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
embeddings
.
word_embeddings
=
new_embeddings
self
.
embeddings
.
word_embeddings
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -382,8 +381,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
...
@@ -382,8 +381,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
...
transformers/modeling_bert.py
View file @
1724cee8
...
@@ -601,13 +601,11 @@ class BertModel(BertPreTrainedModel):
...
@@ -601,13 +601,11 @@ class BertModel(BertPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
embeddings
.
word_embeddings
return
self
.
embeddings
.
word_embeddings
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
value
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
embeddings
.
word_embeddings
=
value
self
.
embeddings
.
word_embeddings
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
""" Prunes heads of the model.
...
@@ -753,8 +751,7 @@ class BertForPreTraining(BertPreTrainedModel):
...
@@ -753,8 +751,7 @@ class BertForPreTraining(BertPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
cls
.
predictions
.
decoder
return
self
.
cls
.
predictions
.
decoder
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
@@ -829,8 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -829,8 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
cls
.
predictions
.
decoder
return
self
.
cls
.
predictions
.
decoder
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
...
transformers/modeling_ctrl.py
View file @
1724cee8
...
@@ -289,12 +289,10 @@ class CTRLModel(CTRLPreTrainedModel):
...
@@ -289,12 +289,10 @@ class CTRLModel(CTRLPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
w
return
self
.
w
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
w
=
new_embeddings
self
.
w
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -454,8 +452,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
...
@@ -454,8 +452,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
...
transformers/modeling_distilbert.py
View file @
1724cee8
...
@@ -421,12 +421,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
...
@@ -421,12 +421,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
embeddings
.
word_embeddings
return
self
.
embeddings
.
word_embeddings
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
embeddings
.
word_embeddings
=
new_embeddings
self
.
embeddings
.
word_embeddings
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -513,8 +511,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
...
@@ -513,8 +511,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self
.
mlm_loss_fct
=
nn
.
CrossEntropyLoss
(
ignore_index
=-
1
)
self
.
mlm_loss_fct
=
nn
.
CrossEntropyLoss
(
ignore_index
=-
1
)
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
vocab_projector
return
self
.
vocab_projector
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
head_mask
=
None
,
masked_lm_labels
=
None
):
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
head_mask
=
None
,
masked_lm_labels
=
None
):
...
...
transformers/modeling_gpt2.py
View file @
1724cee8
...
@@ -357,12 +357,10 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -357,12 +357,10 @@ class GPT2Model(GPT2PreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
wte
return
self
.
wte
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
wte
=
new_embeddings
self
.
wte
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -519,8 +517,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -519,8 +517,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
@@ -623,8 +620,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
...
@@ -623,8 +620,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
past
=
None
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
...
transformers/modeling_openai.py
View file @
1724cee8
...
@@ -360,12 +360,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -360,12 +360,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
tokens_embed
return
self
.
tokens_embed
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
tokens_embed
=
new_embeddings
self
.
tokens_embed
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -494,8 +492,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -494,8 +492,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
@@ -584,8 +581,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -584,8 +581,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
return
self
.
lm_head
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
...
transformers/modeling_roberta.py
View file @
1724cee8
...
@@ -169,10 +169,11 @@ class RobertaModel(BertModel):
...
@@ -169,10 +169,11 @@ class RobertaModel(BertModel):
self
.
embeddings
=
RobertaEmbeddings
(
config
)
self
.
embeddings
=
RobertaEmbeddings
(
config
)
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
embeddings
.
word_embeddings
return
self
.
embeddings
.
word_embeddings
def
set_input_embeddings
(
self
,
value
):
self
.
embeddings
.
word_emebddings
=
value
@
add_start_docstrings
(
"""RoBERTa Model with a `language modeling` head on top. """
,
@
add_start_docstrings
(
"""RoBERTa Model with a `language modeling` head on top. """
,
ROBERTA_START_DOCSTRING
,
ROBERTA_INPUTS_DOCSTRING
)
ROBERTA_START_DOCSTRING
,
ROBERTA_INPUTS_DOCSTRING
)
...
@@ -218,8 +219,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
...
@@ -218,8 +219,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_head
.
decoder
return
self
.
lm_head
.
decoder
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
...
...
transformers/modeling_transfo_xl.py
View file @
1724cee8
...
@@ -639,12 +639,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
...
@@ -639,12 +639,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
word_emb
return
self
.
word_emb
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
word_emb
=
new_embeddings
self
.
word_emb
=
new_embeddings
def
backward_compatible
(
self
):
def
backward_compatible
(
self
):
...
...
transformers/modeling_utils.py
View file @
1724cee8
...
@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
...
@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
def
base_model
(
self
):
def
base_model
(
self
):
return
getattr
(
self
,
self
.
base_model_prefix
,
self
)
return
getattr
(
self
,
self
.
base_model_prefix
,
self
)
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
""" Get model's input embeddings
"""
base_model
=
getattr
(
self
,
self
.
base_model_prefix
,
self
)
base_model
=
getattr
(
self
,
self
.
base_model_prefix
,
self
)
return
base_model
.
input_embeddings
if
base_model
is
not
self
:
return
base_model
.
get_input_embeddings
()
else
:
raise
NotImplementedError
@
property
def
set_input_embeddings
(
self
,
value
):
def
output_embeddings
(
self
):
""" Set model's input embeddings
"""
base_model
=
getattr
(
self
,
self
.
base_model_prefix
,
self
)
if
base_model
is
not
self
:
base_model
.
set_input_embeddings
(
value
)
else
:
raise
NotImplementedError
def
get_output_embeddings
(
self
):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
return
None
# Overwrite for models with output embeddings
return
None
# Overwrite for models with output embeddings
def
tie_weights
(
self
):
def
tie_weights
(
self
):
""" Make sure we are sharing the input and output embeddings.
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
"""
if
self
.
output_embeddings
is
not
None
:
output_embeddings
=
self
.
get_output_embeddings
()
self
.
_tie_or_clone_weights
(
self
.
output_embeddings
,
self
.
input_embeddings
)
if
output_embeddings
is
not
None
:
self
.
_tie_or_clone_weights
(
output_embeddings
,
self
.
get_input_embeddings
())
def
_tie_or_clone_weights
(
self
,
output_embeddings
,
input_embeddings
):
def
_tie_or_clone_weights
(
self
,
output_embeddings
,
input_embeddings
):
""" Tie or clone module weights depending of weither we are using TorchScript or not
""" Tie or clone module weights depending of weither we are using TorchScript or not
...
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
...
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
return
model_embeds
return
model_embeds
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
old_embeddings
=
self
.
input_embeddings
old_embeddings
=
self
.
get_input_embeddings
()
self
.
input_embeddings
=
self
.
_get_resized_embeddings
(
old_embeddings
,
new_num_tokens
)
new_embeddings
=
self
.
_get_resized_embeddings
(
old_embeddings
,
new_num_tokens
)
return
self
.
input_embeddings
self
.
set_input_embeddings
(
new_embeddings
)
return
self
.
get_input_embeddings
()
def
_get_resized_embeddings
(
self
,
old_embeddings
,
new_num_tokens
=
None
):
def
_get_resized_embeddings
(
self
,
old_embeddings
,
new_num_tokens
=
None
):
""" Build a resized Embedding Module from a provided token Embedding Module.
""" Build a resized Embedding Module from a provided token Embedding Module.
...
...
transformers/modeling_xlm.py
View file @
1724cee8
...
@@ -407,12 +407,10 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -407,12 +407,10 @@ class XLMModel(XLMPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
embeddings
return
self
.
embeddings
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
embeddings
=
new_embeddings
self
.
embeddings
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -623,8 +621,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -623,8 +621,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
pred_layer
.
proj
return
self
.
pred_layer
.
proj
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
langs
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
langs
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
...
...
transformers/modeling_xlnet.py
View file @
1724cee8
...
@@ -611,12 +611,10 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -611,12 +611,10 @@ class XLNetModel(XLNetPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_input_embeddings
(
self
):
def
input_embeddings
(
self
):
return
self
.
word_embedding
return
self
.
word_embedding
@
input_embeddings
.
setter
def
set_input_embeddings
(
self
,
new_embeddings
):
def
input_embeddings
(
self
,
new_embeddings
):
self
.
word_embedding
=
new_embeddings
self
.
word_embedding
=
new_embeddings
def
_prune_heads
(
self
,
heads_to_prune
):
def
_prune_heads
(
self
,
heads_to_prune
):
...
@@ -923,8 +921,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -923,8 +921,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self
.
init_weights
()
self
.
init_weights
()
@
property
def
get_output_embeddings
(
self
):
def
output_embeddings
(
self
):
return
self
.
lm_loss
return
self
.
lm_loss
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
...
...
transformers/tests/modeling_common_test.py
View file @
1724cee8
...
@@ -429,6 +429,12 @@ class CommonTestCases:
...
@@ -429,6 +429,12 @@ class CommonTestCases:
list
(
hidden_states
[
0
].
shape
[
-
2
:]),
list
(
hidden_states
[
0
].
shape
[
-
2
:]),
[
self
.
model_tester
.
seq_length
,
self
.
model_tester
.
hidden_size
])
[
self
.
model_tester
.
seq_length
,
self
.
model_tester
.
hidden_size
])
def
test_debug
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
model_embed
=
model
.
resize_token_embeddings
(
config
.
vocab_size
+
10
)
def
test_resize_tokens_embeddings
(
self
):
def
test_resize_tokens_embeddings
(
self
):
original_config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
original_config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
if
not
self
.
test_resize_embeddings
:
if
not
self
.
test_resize_embeddings
:
...
@@ -468,9 +474,9 @@ class CommonTestCases:
...
@@ -468,9 +474,9 @@ class CommonTestCases:
for
model_class
in
self
.
all_model_classes
:
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
model
=
model_class
(
config
)
self
.
assertTrue
(
hasattr
(
model
,
'
input_embeddings
'
)
)
model
.
get_
input_embeddings
(
)
setattr
(
model
,
'
input_embeddings
'
,
torch
.
nn
.
Embedding
(
10
,
10
))
model
.
set_
input_embeddings
(
torch
.
nn
.
Embedding
(
10
,
10
))
self
.
assertTrue
(
hasattr
(
model
,
'
output_embeddings
'
)
)
model
.
get_
output_embeddings
(
)
def
test_tie_model_weights
(
self
):
def
test_tie_model_weights
(
self
):
if
not
self
.
test_torchscript
:
if
not
self
.
test_torchscript
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment