chenpangpang / transformers, commit 6dc0a849 (unverified)
Authored by Matt on Apr 20, 2023; committed by GitHub on Apr 20, 2023
Fix weight tying in TF-ESM (#22839)
Fix weight tying in ESM
Parent: 3b61d289
Showing 2 changed files, with 35 additions and 8 deletions:

- src/transformers/models/esm/modeling_tf_esm.py (+17, -8)
- tests/models/esm/test_modeling_tf_esm.py (+18, -0)
src/transformers/models/esm/modeling_tf_esm.py
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch ESM model."""
 
+import os
 from typing import Optional, Tuple, Union
 
 import numpy as np
```
```diff
@@ -1102,6 +1103,11 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
         self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
         self.lm_head = TFEsmLMHead(config, name="lm_head")
+        if config.tie_word_embeddings:
+            # Ensure word embeddings are built so that we actually have something to tie
+            with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
+                self.esm.embeddings.word_embeddings.build((None, None))
+            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
 
     def get_output_embeddings(self):
         return self.lm_head.decoder
```
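The hunk above builds the word-embedding layer eagerly, then points the LM head's decoder at the embedding variable itself. A minimal standalone sketch of that pattern, assuming TF 2.x Keras (all names here are illustrative, not transformers API):

```python
import tensorflow as tf

# Hypothetical minimal sketch of TF weight tying; names are illustrative.
vocab_size, hidden_size = 100, 16

embedding = tf.keras.layers.Embedding(vocab_size, hidden_size)
# A Keras layer owns no variables until it is built; build it explicitly so
# there is actually a weight to tie to (the reason the commit calls build()).
embedding.build((None, None))
tied_decoder = embedding.weights[0]  # shape (vocab_size, hidden_size)

# The decoder is now an alias of the embedding matrix: updating one updates
# the other, which is the point of weight tying.
hidden_states = tf.random.normal((2, 5, hidden_size))
logits = tf.matmul(hidden_states, tied_decoder, transpose_b=True)
print(logits.shape)  # (2, 5, 100)
```

The explicit tf.name_scope(...) in the commit additionally forces the embedding variable to be created under esm/embeddings/word_embeddings, so checkpoint weight names stay unchanged.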
```diff
@@ -1211,18 +1217,22 @@ class TFEsmLMHead(Layer):
         self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
-        self.decoder = Dense(
-            config.vocab_size,
-            use_bias=False,
-            kernel_initializer=get_initializer(config.initializer_range),
-            name="decoder",
-        )
+        self.decoder = None
         self.config = config
 
     def build(self, input_shape):
         super().build(input_shape)
         # Separate bias to match the PT model and allow weight cross-loading to work
         # Put it in the build so it gets the right name when adding it as a weight
+        if not self.config.tie_word_embeddings:
+            if self.decoder is not None:
+                raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
+            self.decoder = self.add_weight(
+                "decoder.weight",
+                shape=(self.config.hidden_size, self.config.vocab_size),
+                initializer=get_initializer(self.config.initializer_range),
+                trainable=True,
+            )
         self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
 
     def get_bias(self):
```
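Why create the untied decoder inside build() rather than __init__? In Keras, add_weight() called from build() registers the variable under the layer's own name scope, which is what the comment about cross-loading refers to. A hedged sketch of the same tied/untied split (a hypothetical class, not the repo's code; note the sketch keeps the untied kernel in (vocab, hidden) layout so one matmul serves both branches):

```python
import tensorflow as tf

class LMHeadSketch(tf.keras.layers.Layer):
    """Hypothetical head illustrating the tied/untied decoder split."""

    def __init__(self, hidden_size, vocab_size, tie_word_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.decoder = None  # set in build(), or overwritten by the owner when tying

    def build(self, input_shape):
        super().build(input_shape)
        if not self.tie_word_embeddings:
            # Created in build() so the variable is named under this layer's scope.
            self.decoder = self.add_weight(
                name="decoder_weight",
                shape=(self.vocab_size, self.hidden_size),
                trainable=True,
            )
        # The bias stays a separate variable in both branches, as in the commit.
        self.bias = self.add_weight(
            name="bias", shape=(self.vocab_size,), initializer="zeros", trainable=True
        )

    def call(self, x):
        # Assumes self.decoder has been assigned (by build() or by the owner).
        return tf.matmul(x, self.decoder, transpose_b=True) + self.bias
```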
```diff
@@ -1234,8 +1244,7 @@ class TFEsmLMHead(Layer):
         x = self.layer_norm(x)
 
         # project back to size of vocabulary with bias
-        x = self.decoder(x)
-        x = x + self.bias
+        x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
         return x
```
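The new projection line is numerically equivalent to the old bias-free Dense followed by the bias add, provided the decoder matrix is the Dense kernel transposed. A quick illustrative check (not from the repo):

```python
import numpy as np
import tensorflow as tf

vocab_size, hidden_size = 100, 16
decoder = tf.random.normal((vocab_size, hidden_size))  # embedding-shaped kernel
bias = tf.zeros((vocab_size,))
x = tf.random.normal((2, hidden_size))

# Projection via matmul against the (tied) matrix, as in the commit.
via_matmul = tf.matmul(x, decoder, transpose_b=True) + bias

# Equivalent bias-free Dense whose kernel is the transpose of that matrix.
dense = tf.keras.layers.Dense(vocab_size, use_bias=False)
dense.build((None, hidden_size))
dense.set_weights([tf.transpose(decoder).numpy()])
via_dense = dense(x) + bias

np.testing.assert_allclose(via_matmul.numpy(), via_dense.numpy(), rtol=1e-5, atol=1e-5)
```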
tests/models/esm/test_modeling_tf_esm.py
```diff
@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def test_save_load_after_resize_token_embeddings(self):
         pass
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            if model_class is TFEsmForMaskedLM:
+                # Output embedding test differs from the main test because they're a matrix, not a layer
+                name = model.get_bias()
+                assert isinstance(name, dict)
+                for k, v in name.items():
+                    assert isinstance(v, tf.Variable)
+            else:
+                x = model.get_output_embeddings()
+                assert x is None
+                name = model.get_bias()
+                assert name is None
+
 
 @require_tf
 class TFEsmModelIntegrationTest(unittest.TestCase):
```
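Beyond the unit test above, a quick manual check of the fix might look like the following (this snippet is not part of the commit; it assumes network access and the public facebook/esm2_t6_8M_UR50D checkpoint):

```python
from transformers import TFEsmForMaskedLM

model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

input_embeddings = model.get_input_embeddings().weights[0]
output_embeddings = model.get_output_embeddings()

# With config.tie_word_embeddings=True, these should be the same tf.Variable
# object, not merely equal values.
assert output_embeddings is input_embeddings
```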