# test_optimization_tf.py
import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_tf


# The TensorFlow-dependent imports only run when is_tf_available() is truthy,
# so importing this module does not by itself require TensorFlow.
if is_tf_available():
    import tensorflow as tf
    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops

    from transformers import create_optimizer, GradientAccumulator


@require_tf
class OptimizationFTest(unittest.TestCase):
    """Tests for ``GradientAccumulator`` and ``create_optimizer`` under TensorFlow."""

    def assertListAlmostEqual(self, list1, list2, tol):
        """Assert that two sequences have the same length and pairwise-close items.

        Args:
            list1: First sequence of numbers.
            list2: Second sequence of numbers, compared position by position.
            tol: Maximum absolute difference tolerated for each pair; forwarded
                to ``assertAlmostEqual`` as ``delta``.
        """
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol)

    def testGradientAccumulator(self):
        """Accumulate three gradients, check the running sum, then reset."""
        accumulator = GradientAccumulator()
        accumulator([tf.constant([1.0, 2.0])])
        accumulator([tf.constant([-2.0, 1.0])])
        accumulator([tf.constant([-1.0, 2.0])])
        # Passing two gradients after accumulating single-gradient lists is
        # expected to be rejected with ValueError.
        with self.assertRaises(ValueError):
            accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
        # The rejected call must not have advanced the step counter.
        self.assertEqual(accumulator.step, 3)
        self.assertEqual(len(accumulator.gradients), 1)
        # Element-wise sum of the three accepted gradients.
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)

    def testGradientAccumulatorDistributionStrategy(self):
        """Check accumulation, application and reset on two mirrored CPU replicas."""
        # Discard the cached eager context and re-enable eager execution before
        # touching the device configuration.
        # NOTE(review): this relies on private TensorFlow internals; presumably it
        # is needed because logical devices cannot be reconfigured once the
        # runtime has been initialized by an earlier test -- confirm.
        context._context = None
        ops.enable_eager_execution_internal()
        physical_devices = tf.config.list_physical_devices("CPU")
        if len(physical_devices) == 1:
            # Split the single physical CPU into two logical devices so that a
            # two-replica strategy can be built.
            tf.config.set_logical_device_configuration(
                physical_devices[0], [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
            )
        devices = tf.config.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=devices[:2])

        # `Strategy.run` supersedes the deprecated `experimental_run_v2`; keep a
        # fallback so the test still works on TensorFlow versions without `run`.
        run_on_replicas = getattr(strategy, "run", None) or strategy.experimental_run_v2

        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            # NOTE(review): positional arguments are presumably
            # (init_lr, num_train_steps, num_warmup_steps) -- confirm against
            # the create_optimizer signature.
            optimizer, _ = create_optimizer(5e-5, 10, 5)
            # Non-trainable variable used to feed a different gradient to each replica.
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

        def accumulate_on_replica(gradient):
            accumulator([gradient])

        def apply_on_replica():
            optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])))

        @tf.function
        def accumulate(grad1, grad2):
            with strategy.scope():
                # Write one gradient into each replica's local copy of the
                # placeholder, then accumulate it on every replica.
                local_variables = strategy.experimental_local_results(gradient_placeholder)
                local_variables[0].assign(grad1)
                local_variables[1].assign(grad2)
                run_on_replicas(accumulate_on_replica, args=(gradient_placeholder,))

        @tf.function
        def apply_grad():
            with strategy.scope():
                run_on_replicas(apply_on_replica)

        def _check_local_values(grad1, grad2):
            # Compare each replica's accumulated gradient with its expected value.
            values = strategy.experimental_local_results(accumulator._gradients[0])
            self.assertListAlmostEqual(values[0].value(), grad1, tol=1e-2)
            self.assertListAlmostEqual(values[1].value(), grad2, tol=1e-2)

        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        # Per-replica element-wise sums of the three gradients fed above.
        _check_local_values([2.0, 3.0], [1.0, -2.0])
        apply_grad()
        # The variable is expected to stay within tolerance of its initial value
        # after a single optimizer step.
        self.assertListAlmostEqual(variable.value(), [4.0, 3.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        _check_local_values([0.0, 0.0], [0.0, 0.0])