Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6dc0a849
"...lm-evaluation-harness.git" did not exist on "9b9d8dd7652eee05d911f5d94598ea02fb01bdfe"
Unverified
Commit
6dc0a849
authored
Apr 20, 2023
by
Matt
Committed by
GitHub
Apr 20, 2023
Browse files
Fix weight tying in TF-ESM (#22839)
Fix weight tying in ESM
parent
3b61d289
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
8 deletions
+35
-8
src/transformers/models/esm/modeling_tf_esm.py
src/transformers/models/esm/modeling_tf_esm.py
+17
-8
tests/models/esm/test_modeling_tf_esm.py
tests/models/esm/test_modeling_tf_esm.py
+18
-0
No files found.
src/transformers/models/esm/modeling_tf_esm.py
View file @
6dc0a849
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch ESM model."""
import
os
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -1102,6 +1103,11 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
self
.
esm
=
TFEsmMainLayer
(
config
,
add_pooling_layer
=
False
,
name
=
"esm"
)
self
.
lm_head
=
TFEsmLMHead
(
config
,
name
=
"lm_head"
)
if
config
.
tie_word_embeddings
:
# Ensure word embeddings are built so that we actually have something to tie
with
tf
.
name_scope
(
os
.
path
.
join
(
self
.
_name_scope
(),
"esm"
,
"embeddings"
,
"word_embeddings"
)):
self
.
esm
.
embeddings
.
word_embeddings
.
build
((
None
,
None
))
self
.
lm_head
.
decoder
=
self
.
esm
.
embeddings
.
word_embeddings
.
weights
[
0
]
def
get_output_embeddings
(
self
):
return
self
.
lm_head
.
decoder
...
...
@@ -1211,18 +1217,22 @@ class TFEsmLMHead(Layer):
self
.
layer_norm
=
LayerNormalization
(
epsilon
=
config
.
layer_norm_eps
,
name
=
"layer_norm"
)
self
.
decoder
=
Dense
(
config
.
vocab_size
,
use_bias
=
False
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"decoder"
,
)
self
.
decoder
=
None
self
.
config
=
config
def
build
(
self
,
input_shape
):
super
().
build
(
input_shape
)
# Separate bias to match the PT model and allow weight cross-loading to work
# Put it in the build so it gets the right name when adding it as a weight
if
not
self
.
config
.
tie_word_embeddings
:
if
self
.
decoder
is
not
None
:
raise
ValueError
(
"Expected decoder not to be initialized before build when not tying weights!"
)
self
.
decoder
=
self
.
add_weight
(
"decoder.weight"
,
shape
=
(
self
.
config
.
hidden_size
,
self
.
config
.
vocab_size
),
initializer
=
get_initializer
(
self
.
config
.
initializer_range
),
trainable
=
True
,
)
self
.
bias
=
self
.
add_weight
(
"bias"
,
shape
=
(
self
.
config
.
vocab_size
,),
initializer
=
"zeros"
,
trainable
=
True
)
def
get_bias
(
self
):
...
...
@@ -1234,8 +1244,7 @@ class TFEsmLMHead(Layer):
x
=
self
.
layer_norm
(
x
)
# project back to size of vocabulary with bias
x
=
self
.
decoder
(
x
)
x
=
x
+
self
.
bias
x
=
tf
.
matmul
(
x
,
self
.
decoder
,
transpose_b
=
True
)
+
self
.
bias
return
x
...
...
tests/models/esm/test_modeling_tf_esm.py
View file @
6dc0a849
...
...
@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def
test_save_load_after_resize_token_embeddings
(
self
):
pass
def
test_model_common_attributes
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
assert
isinstance
(
model
.
get_input_embeddings
(),
tf
.
keras
.
layers
.
Layer
)
if
model_class
is
TFEsmForMaskedLM
:
# Output embedding test differs from the main test because they're a matrix, not a layer
name
=
model
.
get_bias
()
assert
isinstance
(
name
,
dict
)
for
k
,
v
in
name
.
items
():
assert
isinstance
(
v
,
tf
.
Variable
)
else
:
x
=
model
.
get_output_embeddings
()
assert
x
is
None
name
=
model
.
get_bias
()
assert
name
is
None
@
require_tf
class
TFEsmModelIntegrationTest
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment