"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "92f2fbad508f0f4640e91d5de67601e64e8bd2f3"
Unverified Commit 263fac71 authored by sandip's avatar sandip Committed by GitHub
Browse files

Integration test for electra model (#10073)

parent 781220ac
...@@ -344,3 +344,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -344,3 +344,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = ElectraModel.from_pretrained(model_name) model = ElectraModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
@require_torch
class ElectraModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[-8.9253, -4.0305, -3.9306, -3.8774, -4.1873, -4.1280, 0.9429, -4.1672, 0.9281, 0.0410, -3.4823]]
)
self.assertTrue(torch.allclose(output, expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment