Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3060899b
Unverified
Commit
3060899b
authored
Dec 14, 2023
by
Matt
Committed by
GitHub
Dec 14, 2023
Browse files
Replace build() with build_in_name_scope() for some TF tests (#28046)
Replace build() with build_in_name_scope() for some tests
parent
050e0b44
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
10 deletions
+10
-10
tests/models/bart/test_modeling_tf_bart.py
tests/models/bart/test_modeling_tf_bart.py
+1
-1
tests/models/ctrl/test_modeling_tf_ctrl.py
tests/models/ctrl/test_modeling_tf_ctrl.py
+1
-1
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+4
-4
tests/test_modeling_tf_utils.py
tests/test_modeling_tf_utils.py
+4
-4
No files found.
tests/models/bart/test_modeling_tf_bart.py
View file @
3060899b
...
...
@@ -304,7 +304,7 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
old_total_size
=
config
.
vocab_size
new_total_size
=
old_total_size
+
new_tokens_size
model
=
model_class
(
config
=
copy
.
deepcopy
(
config
))
# `resize_token_embeddings` mutates `config`
model
.
build
()
model
.
build
_in_name_scope
()
model
.
resize_token_embeddings
(
new_total_size
)
# fetch the output for an input exclusively made of new members of the vocabulary
...
...
tests/models/ctrl/test_modeling_tf_ctrl.py
View file @
3060899b
...
...
@@ -225,7 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
for
model_class
in
self
.
all_model_classes
:
model
=
model_class
(
config
)
model
.
build
()
# may be needed for the get_bias() call below
model
.
build
_in_name_scope
()
# may be needed for the get_bias() call below
assert
isinstance
(
model
.
get_input_embeddings
(),
tf
.
keras
.
layers
.
Layer
)
if
model_class
in
list_lm_models
:
...
...
tests/test_modeling_tf_common.py
View file @
3060899b
...
...
@@ -316,7 +316,7 @@ class TFModelTesterMixin:
with
tf
.
Graph
().
as_default
()
as
g
:
model
=
model_class
(
config
)
model
.
build
()
model
.
build
_in_name_scope
()
for
op
in
g
.
get_operations
():
model_op_names
.
add
(
op
.
node_def
.
op
)
...
...
@@ -346,7 +346,7 @@ class TFModelTesterMixin:
for
model_class
in
self
.
all_model_classes
[:
2
]:
model
=
model_class
(
config
)
model
.
build
()
model
.
build
_in_name_scope
()
onnx_model_proto
,
_
=
tf2onnx
.
convert
.
from_keras
(
model
,
opset
=
self
.
onnx_min_opset
)
...
...
@@ -1088,7 +1088,7 @@ class TFModelTesterMixin:
def
_get_word_embedding_weight
(
model
,
embedding_layer
):
if
isinstance
(
embedding_layer
,
tf
.
keras
.
layers
.
Embedding
):
# builds the embeddings layer
model
.
build
()
model
.
build
_in_name_scope
()
return
embedding_layer
.
embeddings
else
:
return
model
.
_get_word_embedding_weight
(
embedding_layer
)
...
...
@@ -1151,7 +1151,7 @@ class TFModelTesterMixin:
old_total_size
=
config
.
vocab_size
new_total_size
=
old_total_size
+
new_tokens_size
model
=
model_class
(
config
=
copy
.
deepcopy
(
config
))
# `resize_token_embeddings` mutates `config`
model
.
build
()
model
.
build
_in_name_scope
()
model
.
resize_token_embeddings
(
new_total_size
)
# fetch the output for an input exclusively made of new members of the vocabulary
...
...
tests/test_modeling_tf_utils.py
View file @
3060899b
...
...
@@ -402,8 +402,8 @@ class TFModelUtilsTest(unittest.TestCase):
# Finally, check the model can be reloaded
new_model
=
TFBertModel
.
from_pretrained
(
tmp_dir
)
model
.
build
()
new_model
.
build
()
model
.
build
_in_name_scope
()
new_model
.
build
_in_name_scope
()
for
p1
,
p2
in
zip
(
model
.
weights
,
new_model
.
weights
):
self
.
assertTrue
(
np
.
allclose
(
p1
.
numpy
(),
p2
.
numpy
()))
...
...
@@ -632,7 +632,7 @@ class TFModelPushToHubTester(unittest.TestCase):
)
model
=
TFBertModel
(
config
)
# Make sure model is properly initialized
model
.
build
()
model
.
build
_in_name_scope
()
logging
.
set_verbosity_info
()
logger
=
logging
.
get_logger
(
"transformers.utils.hub"
)
...
...
@@ -701,7 +701,7 @@ class TFModelPushToHubTester(unittest.TestCase):
)
model
=
TFBertModel
(
config
)
# Make sure model is properly initialized
model
.
build
()
model
.
build
_in_name_scope
()
model
.
push_to_hub
(
"valid_org/test-model-tf-org"
,
token
=
self
.
_token
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment