Unverified Commit 75fd00fb authored by sandip's avatar sandip Committed by GitHub
Browse files

Integration test added for TF MPnet (#9979)

parent ce08043f
...@@ -240,3 +240,26 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -240,3 +240,26 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
for model_name in ["microsoft/mpnet-base"]: for model_name in ["microsoft/mpnet-base"]:
model = TFMPNetModel.from_pretrained(model_name) model = TFMPNetModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@require_tf
class TFMPNetModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = TFMPNetModel.from_pretrained("microsoft/mpnet-base")
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
expected_shape = [1, 6, 768]
self.assertEqual(output.shape, expected_shape)
expected_slice = tf.constant(
[
[
[-0.1067172, 0.08216473, 0.0024543],
[-0.03465879, 0.8354118, -0.03252288],
[-0.06569476, -0.12424111, -0.0494436],
]
]
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment