Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4107445a
"tests/test_modeling_tf_deberta.py" did not exist on "bf7f79cd5720ea74e85d96320ec2b71d5b138589"
Unverified
Commit
4107445a
authored
Oct 10, 2022
by
Matt
Committed by
GitHub
Oct 10, 2022
Browse files
Fix repo names for ESM tests (#19451)
parent
cbb8a379
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
tests/models/esm/test_modeling_esm.py
tests/models/esm/test_modeling_esm.py
+3
-3
No files found.
tests/models/esm/test_modeling_esm.py
View file @
4107445a
...
@@ -245,7 +245,7 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
...
@@ -245,7 +245,7 @@ class EsmModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
class
EsmModelIntegrationTest
(
TestCasePlus
):
class
EsmModelIntegrationTest
(
TestCasePlus
):
@
slow
@
slow
def
test_inference_masked_lm
(
self
):
def
test_inference_masked_lm
(
self
):
model
=
EsmForMaskedLM
.
from_pretrained
(
"Rocketknight1/esm
-2-8m
"
)
model
=
EsmForMaskedLM
.
from_pretrained
(
"Rocketknight1/esm
2_t6_8M_UR50D
"
)
input_ids
=
torch
.
tensor
([[
0
,
1
,
2
,
3
,
4
,
5
]])
input_ids
=
torch
.
tensor
([[
0
,
1
,
2
,
3
,
4
,
5
]])
output
=
model
(
input_ids
)[
0
]
output
=
model
(
input_ids
)[
0
]
...
@@ -261,7 +261,7 @@ class EsmModelIntegrationTest(TestCasePlus):
...
@@ -261,7 +261,7 @@ class EsmModelIntegrationTest(TestCasePlus):
@
slow
@
slow
def
test_inference_no_head
(
self
):
def
test_inference_no_head
(
self
):
model
=
EsmModel
.
from_pretrained
(
"Rocketknight1/esm
-2-8m
"
)
model
=
EsmModel
.
from_pretrained
(
"Rocketknight1/esm
2_t6_8M_UR50D
"
)
input_ids
=
torch
.
tensor
([[
0
,
6
,
4
,
13
,
5
,
4
,
16
,
12
,
11
,
7
,
2
]])
input_ids
=
torch
.
tensor
([[
0
,
6
,
4
,
13
,
5
,
4
,
16
,
12
,
11
,
7
,
2
]])
output
=
model
(
input_ids
)[
0
]
output
=
model
(
input_ids
)[
0
]
...
@@ -276,7 +276,7 @@ class EsmModelIntegrationTest(TestCasePlus):
...
@@ -276,7 +276,7 @@ class EsmModelIntegrationTest(TestCasePlus):
keys_to_ignore_on_save_tied
=
[
r
"lm_head.decoder.weight"
,
r
"lm_head.decoder.bias"
]
keys_to_ignore_on_save_tied
=
[
r
"lm_head.decoder.weight"
,
r
"lm_head.decoder.bias"
]
keys_to_ignore_on_save_untied
=
[
r
"lm_head.decoder.bias"
]
keys_to_ignore_on_save_untied
=
[
r
"lm_head.decoder.bias"
]
config
=
EsmConfig
.
from_pretrained
(
"Rocketknight1/esm
-2-8m
"
)
config
=
EsmConfig
.
from_pretrained
(
"Rocketknight1/esm
2_t6_8M_UR50D
"
)
config_tied
=
deepcopy
(
config
)
config_tied
=
deepcopy
(
config
)
config_tied
.
tie_word_embeddings
=
True
config_tied
.
tie_word_embeddings
=
True
config_untied
=
deepcopy
(
config
)
config_untied
=
deepcopy
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment