import unittest

from numpy import ndarray

from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch


if is_flax_available():
    import os

    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8

    import jax
    from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel

if is_torch_available():
    import torch

    from transformers.models.roberta.modeling_roberta import RobertaModel


@require_flax
@require_torch
class FlaxRobertaModelTest(unittest.TestCase):
    """Tests for `FlaxRobertaModel`: agreement with the PyTorch `RobertaModel` and output shapes under `jax.jit`."""

    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
        """Fail unless `a` and `b` have the same shape and agree element-wise within `tol`.

        Args:
            a: First array to compare.
            b: Second array to compare.
            tol: Largest absolute element-wise difference that is still accepted.
        """
        # Without this check, broadcasting could silently hide a shape mismatch.
        self.assertEqual(a.shape, b.shape, f"Shapes differ: {a.shape} vs {b.shape}")

        # Compare magnitudes. A signed sum lets positive and negative errors cancel each other,
        # and an arbitrarily large negative total would always satisfy `<= tol`.
        diff = abs(a - b).max()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol})")

    def test_from_pytorch(self):
        """Load the same checkpoint in Flax and PyTorch and compare every output on one simple input."""
        checkpoint = "roberta-base"
        text = "This is a simple input"

        with torch.no_grad(), self.subTest(checkpoint):
            tokenizer = RobertaTokenizerFast.from_pretrained(checkpoint)
            fx_model = FlaxRobertaModel.from_pretrained(checkpoint)
            pt_model = RobertaModel.from_pretrained(checkpoint)

            # Check for simple input: the same text is encoded once per framework.
            pt_inputs = tokenizer.encode_plus(text, return_tensors=TensorType.PYTORCH)
            fx_inputs = tokenizer.encode_plus(text, return_tensors=TensorType.JAX)
            pt_outputs = pt_model(**pt_inputs)
            fx_outputs = fx_model(**fx_inputs)

            self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

            # NOTE(review): 6e-4 is now a bound on the largest per-element difference rather than
            # on a signed sum; confirm it is appropriate for this checkpoint on CI hardware.
            for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
                self.assert_almost_equals(fx_output, pt_output.numpy(), 6e-4)

    def test_multiple_sequences(self):
        """Run a padded batch of three sentences through the Flax model, with JIT disabled and enabled."""
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        model = FlaxRobertaModel.from_pretrained("roberta-base")

        sequences = ["this is an example sentence", "this is another", "and a third one"]
        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)

        @jax.jit
        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
            return model(input_ids, attention_mask, token_type_ids)

        # The same wrapped function is exercised twice: first inside `jax.disable_jit()`,
        # then normally, so both execution paths are checked against identical expectations.
        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                tokens, pooled = model_jitted(**encodings)
                self.assertEqual(tokens.shape, (3, 7, 768))
                self.assertEqual(pooled.shape, (3, 768))

        with self.subTest("JIT Enabled"):
            jitted_tokens, jitted_pooled = model_jitted(**encodings)

            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
            self.assertEqual(jitted_pooled.shape, (3, 768))