"vscode:/vscode.git/clone" did not exist on "38ddab10da90e64297a37c0719ed9309e693317a"
Unverified Commit 72b19ca6 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix ESM checkpoints for tests (#20436)

* Re-enable TF ESM tests, make sure we use facebook checkpoints

* make fixup
parent f244a978
......@@ -274,7 +274,7 @@ class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_masked_lm(self):
with torch.no_grad():
model = EsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
......@@ -292,7 +292,7 @@ class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_no_head(self):
with torch.no_grad():
model = EsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
......
......@@ -247,7 +247,7 @@ class EsmFoldModelTest(ModelTesterMixin, unittest.TestCase):
class EsmModelIntegrationTest(TestCasePlus):
@slow
def test_inference_protein_folding(self):
model = EsmForProteinFolding.from_pretrained("Rocketknight1/esmfold_v1").float()
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").float()
model.eval()
input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
position_outputs = model(input_ids)["positions"]
......
......@@ -254,9 +254,9 @@ class TFEsmModelTest(TFModelTesterMixin, unittest.TestCase):
@require_tf
class TFEsmModelIntegrationTest(unittest.TestCase):
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
@slow
def test_inference_masked_lm(self):
model = TFEsmForMaskedLM.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model = TFEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
......@@ -264,13 +264,19 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
self.assertEqual(list(output.numpy().shape), expected_shape)
# compare the actual values for a slice.
expected_slice = tf.constant(
[[[15.0963, -6.6414, -1.1346], [-0.2209, -9.9633, 4.2082], [-1.6045, -10.0011, 1.5882]]]
[
[
[8.920963, -10.591399, -6.467397],
[-6.3980846, -13.913257, -1.1291938],
[-7.7815733, -13.951929, -3.7438734],
]
]
)
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@unittest.skip("Temporarily disabled as we update ESM model checkpoints")
@slow
def test_inference_no_head(self):
model = TFEsmModel.from_pretrained("Rocketknight1/esm2_t6_8M_UR50D")
model = TFEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
input_ids = tf.constant([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
output = model(input_ids)[0]
......@@ -278,9 +284,9 @@ class TFEsmModelIntegrationTest(unittest.TestCase):
expected_slice = tf.constant(
[
[
[0.144337, 0.541198, 0.32479298],
[0.30328932, 0.00519154, 0.31089523],
[0.32273883, -0.24992886, 0.34143737],
[0.14422388, 0.5411936, 0.3249576],
[0.30342406, 0.00549317, 0.31096306],
[0.32278833, -0.24974644, 0.34135976],
]
]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment