Unverified Commit 1486205d authored by sandip's avatar sandip Committed by GitHub
Browse files

TF DistilBERT integration tests (#9975)

* TF DistilBERT integration test

* Update test_modeling_tf_distilbert.py
parent f2d5c04e
......@@ -221,3 +221,26 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
for model_name in list(TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]):
model = TFDistilBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_tf
class TFDistilBertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = TFDistilBertModel.from_pretrained("distilbert-base-uncased")
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
output = model(input_ids)[0]
expected_shape = [1, 6, 768]
self.assertEqual(output.shape, expected_shape)
expected_slice = tf.constant(
[
[
[0.19261885, -0.13732955, 0.4119799],
[0.22150156, -0.07422661, 0.39037204],
[0.22756018, -0.0896414, 0.3701467],
]
]
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment