from typing import Callable, Tuple

import torch
from torch import Tensor
from torch.autograd import gradcheck

from parameterized import parameterized
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
)

from .numpy_transducer import NumpyTransducerLoss
from .utils import (
    get_B1_T10_U3_D4_data,
    get_B2_T4_U3_D3_data,
    get_B1_T2_U3_D5_data,
)


class Autograd(TestBaseMixin):
    """Gradcheck tests for the RNNT loss (class API, functional API, and
    the NumPy reference implementation)."""

    @staticmethod
    def get_data(data_func, device):
        """Build a test-data dict by calling ``data_func``.

        Some factories return ``(data_dict, ...)`` tuples; only the dict is
        needed here, so unwrap it.  ``device`` is kept for signature
        compatibility but unused — tensors are moved in ``assert_grad``.
        """
        data = data_func()
        if isinstance(data, tuple):
            data = data[0]
        return data

    def assert_grad(
            self,
            loss: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            *,
            enable_all_grad: bool = True,
    ):
        """Run ``torch.autograd.gradcheck`` on ``loss`` over ``inputs``.

        Tensor inputs are cast to ``self.dtype`` and moved to
        ``self.device``; when ``enable_all_grad`` is True, every tensor
        input additionally gets ``requires_grad=True`` before the check.
        Non-tensor inputs (e.g. ``blank``/``clamp`` scalars) pass through
        unchanged.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # BUG FIX: gradcheck must receive the converted tensors (inputs_),
        # not the raw ``inputs`` — otherwise the dtype/device conversion and
        # requires_grad flags set above are silently discarded.
        # gradcheck with float32 requires higher atol and epsilon
        assert gradcheck(loss, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_B2_T4_U3_D3_data, ),
        (get_B1_T2_U3_D5_data, ),
    ])
    def test_RNNTLoss_gradcheck(self, data_func):
        """Gradcheck the ``RNNTLoss`` module API on each data fixture."""
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),
            data["targets"],
            data["logit_lengths"],
            data["target_lengths"],
        )
        loss = RNNTLoss(blank=data["blank"])

        # Fixtures set requires_grad on logits themselves, so don't force it.
        self.assert_grad(loss, inputs, enable_all_grad=False)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_B2_T4_U3_D3_data, ),
        (get_B1_T2_U3_D5_data, ),
    ])
    def test_rnnt_loss_gradcheck(self, data_func):
        """Gradcheck the functional ``rnnt_loss`` API on each data fixture."""
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),  # logits
            data["targets"],                # targets
            data["logit_lengths"],          # logit_lengths
            data["target_lengths"],         # target_lengths
            data["blank"],                  # blank
            -1,                             # clamp (-1 disables clamping)
        )

        self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)

    @parameterized.expand([
        (get_B1_T10_U3_D4_data, ),
        (get_B2_T4_U3_D3_data, ),
        (get_B1_T2_U3_D5_data, ),
    ])
    def test_np_transducer_gradcheck(self, data_func):
        """Gradcheck the NumPy reference transducer loss on each fixture.

        NOTE: argument order differs from RNNTLoss — targets come last.
        """
        data = self.get_data(data_func, self.device)
        inputs = (
            data["logits"].to(self.dtype),
            data["logit_lengths"],
            data["target_lengths"],
            data["targets"],
        )
        loss = NumpyTransducerLoss(blank=data["blank"])

        self.assert_grad(loss, inputs, enable_all_grad=False)