Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0eabab09
Unverified
Commit
0eabab09
authored
Sep 07, 2022
by
Joao Gante
Committed by
GitHub
Sep 07, 2022
Browse files
TF: final bias as a layer in seq2seq models (replicate TFMarian fix) (#18903)
parent
2b9513fd
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
147 additions
and
14 deletions
+147
-14
src/transformers/models/bart/modeling_tf_bart.py
src/transformers/models/bart/modeling_tf_bart.py
+20
-2
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
+21
-2
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
...s/models/blenderbot_small/modeling_tf_blenderbot_small.py
+21
-2
src/transformers/models/led/modeling_tf_led.py
src/transformers/models/led/modeling_tf_led.py
+21
-2
src/transformers/models/marian/modeling_tf_marian.py
src/transformers/models/marian/modeling_tf_marian.py
+1
-0
src/transformers/models/mbart/modeling_tf_mbart.py
src/transformers/models/mbart/modeling_tf_mbart.py
+21
-2
src/transformers/models/pegasus/modeling_tf_pegasus.py
src/transformers/models/pegasus/modeling_tf_pegasus.py
+21
-2
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
...ame}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+21
-2
No files found.
src/transformers/models/bart/modeling_tf_bart.py
View file @
0eabab09
...
@@ -1251,6 +1251,23 @@ class TFBartModel(TFBartPretrainedModel):
...
@@ -1251,6 +1251,23 @@ class TFBartModel(TFBartPretrainedModel):
)
)
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The BART Model with a language modeling head. Can be used for summarization."
,
"The BART Model with a language modeling head. Can be used for summarization."
,
BART_START_DOCSTRING
,
BART_START_DOCSTRING
,
...
@@ -1268,9 +1285,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
...
@@ -1268,9 +1285,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
self
.
model
=
TFBartMainLayer
(
config
,
load_weight_prefix
=
load_weight_prefix
,
name
=
"model"
)
self
.
model
=
TFBartMainLayer
(
config
,
load_weight_prefix
=
load_weight_prefix
,
name
=
"model"
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
model
.
decoder
return
self
.
model
.
decoder
...
@@ -1357,7 +1375,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
...
@@ -1357,7 +1375,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
training
=
training
,
training
=
training
,
)
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
src/transformers/models/blenderbot/modeling_tf_blenderbot.py
View file @
0eabab09
...
@@ -1239,6 +1239,24 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
...
@@ -1239,6 +1239,24 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The BLENDERBOT Model with a language modeling head. Can be used for summarization."
,
"The BLENDERBOT Model with a language modeling head. Can be used for summarization."
,
BLENDERBOT_START_DOCSTRING
,
BLENDERBOT_START_DOCSTRING
,
...
@@ -1254,9 +1272,10 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
...
@@ -1254,9 +1272,10 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
self
.
model
=
TFBlenderbotMainLayer
(
config
,
name
=
"model"
)
self
.
model
=
TFBlenderbotMainLayer
(
config
,
name
=
"model"
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
model
.
decoder
return
self
.
model
.
decoder
...
@@ -1358,7 +1377,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
...
@@ -1358,7 +1377,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
training
=
training
,
training
=
training
,
)
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py
View file @
0eabab09
...
@@ -1226,6 +1226,24 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
...
@@ -1226,6 +1226,24 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization."
,
"The BLENDERBOT_SMALL Model with a language modeling head. Can be used for summarization."
,
BLENDERBOT_SMALL_START_DOCSTRING
,
BLENDERBOT_SMALL_START_DOCSTRING
,
...
@@ -1241,9 +1259,10 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
...
@@ -1241,9 +1259,10 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
self
.
model
=
TFBlenderbotSmallMainLayer
(
config
,
name
=
"model"
)
self
.
model
=
TFBlenderbotSmallMainLayer
(
config
,
name
=
"model"
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
model
.
decoder
return
self
.
model
.
decoder
...
@@ -1330,7 +1349,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
...
@@ -1330,7 +1349,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
training
=
training
,
training
=
training
,
)
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
src/transformers/models/led/modeling_tf_led.py
View file @
0eabab09
...
@@ -2316,6 +2316,24 @@ class TFLEDModel(TFLEDPreTrainedModel):
...
@@ -2316,6 +2316,24 @@ class TFLEDModel(TFLEDPreTrainedModel):
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The LED Model with a language modeling head. Can be used for summarization."
,
"The LED Model with a language modeling head. Can be used for summarization."
,
LED_START_DOCSTRING
,
LED_START_DOCSTRING
,
...
@@ -2331,9 +2349,10 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
...
@@ -2331,9 +2349,10 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
self
.
led
=
TFLEDMainLayer
(
config
,
name
=
"led"
)
self
.
led
=
TFLEDMainLayer
(
config
,
name
=
"led"
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
# TODO (Joao): investigate why LED has numerical issues in XLA generate
# TODO (Joao): investigate why LED has numerical issues in XLA generate
self
.
supports_xla_generation
=
False
self
.
supports_xla_generation
=
False
...
@@ -2423,7 +2442,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
...
@@ -2423,7 +2442,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
training
=
training
,
training
=
training
,
)
)
lm_logits
=
self
.
led
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
led
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
src/transformers/models/marian/modeling_tf_marian.py
View file @
0eabab09
...
@@ -1269,6 +1269,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
...
@@ -1269,6 +1269,7 @@ class TFMarianModel(TFMarianPreTrainedModel):
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
...
...
src/transformers/models/mbart/modeling_tf_mbart.py
View file @
0eabab09
...
@@ -1266,6 +1266,24 @@ class TFMBartModel(TFMBartPreTrainedModel):
...
@@ -1266,6 +1266,24 @@ class TFMBartModel(TFMBartPreTrainedModel):
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The MBART Model with a language modeling head. Can be used for summarization."
,
"The MBART Model with a language modeling head. Can be used for summarization."
,
MBART_START_DOCSTRING
,
MBART_START_DOCSTRING
,
...
@@ -1281,9 +1299,10 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
...
@@ -1281,9 +1299,10 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
self
.
model
=
TFMBartMainLayer
(
config
,
name
=
"model"
)
self
.
model
=
TFMBartMainLayer
(
config
,
name
=
"model"
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
model
.
decoder
return
self
.
model
.
decoder
...
@@ -1368,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
...
@@ -1368,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
training
=
training
,
training
=
training
,
)
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
src/transformers/models/pegasus/modeling_tf_pegasus.py
View file @
0eabab09
...
@@ -1278,6 +1278,24 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
...
@@ -1278,6 +1278,24 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The PEGASUS Model with a language modeling head. Can be used for summarization."
,
"The PEGASUS Model with a language modeling head. Can be used for summarization."
,
PEGASUS_START_DOCSTRING
,
PEGASUS_START_DOCSTRING
,
...
@@ -1293,9 +1311,10 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
...
@@ -1293,9 +1311,10 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
self
.
model
=
TFPegasusMainLayer
(
config
,
name
=
"model"
)
self
.
model
=
TFPegasusMainLayer
(
config
,
name
=
"model"
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
model
.
decoder
return
self
.
model
.
decoder
...
@@ -1382,7 +1401,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
...
@@ -1382,7 +1401,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
training
=
training
,
training
=
training
,
)
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
View file @
0eabab09
...
@@ -2806,6 +2806,24 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
...
@@ -2806,6 +2806,24 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
)
)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class
BiasLayer
(
tf
.
keras
.
layers
.
Layer
):
"""
Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis,
so all weights have to be registered in a layer.
"""
def
__init__
(
self
,
shape
,
initializer
,
trainable
,
name
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
# Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
# "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
# https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
self
.
bias
=
self
.
add_weight
(
name
=
name
,
shape
=
shape
,
initializer
=
initializer
,
trainable
=
trainable
)
def
call
(
self
,
x
):
return
x
+
self
.
bias
@
add_start_docstrings
(
@
add_start_docstrings
(
"The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. Can be used for summarization."
,
"The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. Can be used for summarization."
,
{{
cookiecutter
.
uppercase_modelname
}}
_START_DOCSTRING
,
{{
cookiecutter
.
uppercase_modelname
}}
_START_DOCSTRING
,
...
@@ -2822,9 +2840,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
...
@@ -2822,9 +2840,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
self
.
model
.
_set_save_spec
(
inputs
=
self
.
serving
.
input_signature
)
self
.
model
.
_set_save_spec
(
inputs
=
self
.
serving
.
input_signature
)
self
.
use_cache
=
config
.
use_cache
self
.
use_cache
=
config
.
use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self
.
final_logits_bias
=
self
.
add_weight
(
self
.
bias_layer
=
BiasLayer
(
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
name
=
"final_logits_bias"
,
shape
=
[
1
,
config
.
vocab_size
],
initializer
=
"zeros"
,
trainable
=
False
)
)
self
.
final_logits_bias
=
self
.
bias_layer
.
bias
# alias to keep the same interface with PT
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
model
.
decoder
return
self
.
model
.
decoder
...
@@ -2911,7 +2930,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
...
@@ -2911,7 +2930,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
training
=
training
training
=
training
)
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
self
.
model
.
shared
(
outputs
[
0
],
mode
=
"linear"
)
lm_logits
=
lm_logits
+
self
.
final
_logits
_bias
lm_logits
=
self
.
bias_layer
(
lm
_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
masked_lm_loss
=
None
if
labels
is
None
else
self
.
hf_compute_loss
(
labels
,
lm_logits
)
if
not
return_dict
:
if
not
return_dict
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment