test_activations_tf.py 694 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_tf


if is_tf_available():
    from transformers.activations_tf import get_tf_activation


@require_tf
class TestTFActivations(unittest.TestCase):
    """Tests for the TensorFlow activation lookup ``get_tf_activation``."""

    # Activation names that ``get_tf_activation`` is expected to resolve.
    SUPPORTED_ACTIVATIONS = (
        "swish",
        "gelu",
        "relu",
        "tanh",
        "gelu_new",
        "gelu_fast",
        "mish",
    )

    # Inputs that must be rejected with ``KeyError``.
    INVALID_ACTIVATIONS = ("bogus", None)

    def test_get_activation(self):
        """Known names resolve to callables; unknown names and ``None`` raise ``KeyError``."""
        for name in self.SUPPORTED_ACTIVATIONS:
            # subTest reports every failing name instead of stopping at the first one.
            with self.subTest(activation=name):
                # Resolving without error is not enough: the result must be usable as a function.
                self.assertTrue(callable(get_tf_activation(name)))
        for bad_name in self.INVALID_ACTIVATIONS:
            with self.subTest(activation=bad_name):
                with self.assertRaises(KeyError):
                    get_tf_activation(bad_name)