Unverified Commit 0ea6d10d authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Refactor RNNT Loss Unit Tests (#1630)

parent 56ab0368
...@@ -8,10 +8,9 @@ from torchaudio_unittest.common_utils import ( ...@@ -8,10 +8,9 @@ from torchaudio_unittest.common_utils import (
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
from parameterized import parameterized from parameterized import parameterized
from .utils import ( from .utils import (
numpy_to_torch,
get_B1_T10_U3_D4_data, get_B1_T10_U3_D4_data,
get_numpy_data_B2_T4_U3_D3, get_B2_T4_U3_D3_data,
get_numpy_data_B1_T2_U3_D5 get_B1_T2_U3_D5_data
) )
from .numpy_transducer import NumpyTransducerLoss from .numpy_transducer import NumpyTransducerLoss
...@@ -19,12 +18,9 @@ from .numpy_transducer import NumpyTransducerLoss ...@@ -19,12 +18,9 @@ from .numpy_transducer import NumpyTransducerLoss
class Autograd(TestBaseMixin): class Autograd(TestBaseMixin):
@staticmethod @staticmethod
def get_data(data_func, device): def get_data(data_func, device):
data_np = data_func() data = data_func()
if type(data_np) == tuple: if type(data) == tuple:
data_np = data_np[0] data = data[0]
data = numpy_to_torch(
data=data_np, device=device, requires_grad=True
)
return data return data
def assert_grad( def assert_grad(
...@@ -46,8 +42,8 @@ class Autograd(TestBaseMixin): ...@@ -46,8 +42,8 @@ class Autograd(TestBaseMixin):
@parameterized.expand([ @parameterized.expand([
(get_B1_T10_U3_D4_data, ), (get_B1_T10_U3_D4_data, ),
(get_numpy_data_B2_T4_U3_D3, ), (get_B2_T4_U3_D3_data, ),
(get_numpy_data_B1_T2_U3_D5, ), (get_B1_T2_U3_D5_data, ),
]) ])
def test_RNNTLoss_gradcheck(self, data_func): def test_RNNTLoss_gradcheck(self, data_func):
data = self.get_data(data_func, self.device) data = self.get_data(data_func, self.device)
...@@ -63,8 +59,8 @@ class Autograd(TestBaseMixin): ...@@ -63,8 +59,8 @@ class Autograd(TestBaseMixin):
@parameterized.expand([ @parameterized.expand([
(get_B1_T10_U3_D4_data, ), (get_B1_T10_U3_D4_data, ),
(get_numpy_data_B2_T4_U3_D3, ), (get_B2_T4_U3_D3_data, ),
(get_numpy_data_B1_T2_U3_D5, ), (get_B1_T2_U3_D5_data, ),
]) ])
def test_rnnt_loss_gradcheck(self, data_func): def test_rnnt_loss_gradcheck(self, data_func):
data = self.get_data(data_func, self.device) data = self.get_data(data_func, self.device)
...@@ -83,8 +79,8 @@ class Autograd(TestBaseMixin): ...@@ -83,8 +79,8 @@ class Autograd(TestBaseMixin):
@parameterized.expand([ @parameterized.expand([
(get_B1_T10_U3_D4_data, ), (get_B1_T10_U3_D4_data, ),
(get_numpy_data_B2_T4_U3_D3, ), (get_B2_T4_U3_D3_data, ),
(get_numpy_data_B1_T2_U3_D5, ), (get_B1_T2_U3_D5_data, ),
]) ])
def test_np_transducer_gradcheck(self, data_func): def test_np_transducer_gradcheck(self, data_func):
data = self.get_data(data_func, self.device) data = self.get_data(data_func, self.device)
......
import numpy as np import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss from torchaudio.prototype.rnnt_loss import RNNTLoss
from .utils import ( from .utils import (
compute_with_numpy_transducer, compute_with_numpy_transducer,
compute_with_pytorch_transducer, compute_with_pytorch_transducer,
get_basic_data,
get_B1_T10_U3_D4_data, get_B1_T10_U3_D4_data,
get_data_basic, get_B1_T2_U3_D5_data,
get_numpy_data_B1_T2_U3_D5, get_B2_T4_U3_D3_data,
get_numpy_data_B2_T4_U3_D3, get_random_data,
get_numpy_random_data,
numpy_to_torch,
) )
...@@ -23,42 +22,30 @@ class RNNTLossTest: ...@@ -23,42 +22,30 @@ class RNNTLossTest:
costs, gradients = compute_with_pytorch_transducer( costs, gradients = compute_with_pytorch_transducer(
data=data, reuse_logits_for_grads=reuse_logits_for_grads data=data, reuse_logits_for_grads=reuse_logits_for_grads
) )
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol) self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape) self.assertEqual(logits_shape, gradients.shape)
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol): self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
for b in range(len(gradients)):
T = data["logit_lengths"][b]
U = data["target_lengths"][b]
for t in range(gradients.shape[1]):
for u in range(gradients.shape[2]):
np.testing.assert_allclose(
gradients[b, t, u],
ref_gradients[b, t, u],
atol=atol,
rtol=rtol,
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
)
def test_basic_backward(self): def test_basic_backward(self):
rnnt_loss = RNNTLoss() rnnt_loss = RNNTLoss()
logits, targets, logit_lengths, target_lengths = get_data_basic(self.device) logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
loss = rnnt_loss(logits, targets, logit_lengths, target_lengths) loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
loss.backward() loss.backward()
def test_costs_and_gradients_B1_T2_U3_D5_fp32(self): def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5( data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
dtype=np.float32 dtype=torch.float32,
device=self.device,
) )
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients( self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
) )
def test_costs_and_gradients_B1_T2_U3_D5_fp16(self): def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5( data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
dtype=np.float16 dtype=torch.float16,
device=self.device,
) )
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients( self._test_costs_and_gradients(
data=data, data=data,
ref_costs=ref_costs, ref_costs=ref_costs,
...@@ -68,19 +55,19 @@ class RNNTLossTest: ...@@ -68,19 +55,19 @@ class RNNTLossTest:
) )
def test_costs_and_gradients_B2_T4_U3_D3_fp32(self): def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3( data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
dtype=np.float32 dtype=torch.float32,
device=self.device,
) )
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients( self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
) )
def test_costs_and_gradients_B2_T4_U3_D3_fp16(self): def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3( data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
dtype=np.float16 dtype=torch.float16,
device=self.device,
) )
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients( self._test_costs_and_gradients(
data=data, data=data,
ref_costs=ref_costs, ref_costs=ref_costs,
...@@ -92,8 +79,7 @@ class RNNTLossTest: ...@@ -92,8 +79,7 @@ class RNNTLossTest:
def test_costs_and_gradients_random_data_with_numpy_fp32(self): def test_costs_and_gradients_random_data_with_numpy_fp32(self):
seed = 777 seed = 777
for i in range(5): for i in range(5):
data = get_numpy_random_data(dtype=np.float32, seed=(seed + i)) data = get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
ref_costs, ref_gradients = compute_with_numpy_transducer(data=data) ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients( self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
...@@ -103,9 +89,8 @@ class RNNTLossTest: ...@@ -103,9 +89,8 @@ class RNNTLossTest:
for random in [False, True]: for random in [False, True]:
data = get_B1_T10_U3_D4_data( data = get_B1_T10_U3_D4_data(
random=random, random=random,
) dtype=torch.float32,
data = numpy_to_torch( device=self.device,
data=data, device=self.device, requires_grad=True
) )
data["fused_log_softmax"] = False data["fused_log_softmax"] = False
ref_costs, ref_gradients = compute_with_numpy_transducer( ref_costs, ref_gradients = compute_with_numpy_transducer(
......
import unittest import unittest
import random
import numpy as np
import torch import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss from torchaudio.prototype.rnnt_loss import RNNTLoss
...@@ -19,10 +18,8 @@ def compute_with_numpy_transducer(data): ...@@ -19,10 +18,8 @@ def compute_with_numpy_transducer(data):
loss = torch.sum(costs) loss = torch.sum(costs)
loss.backward() loss.backward()
costs = costs.cpu()
costs = costs.cpu().data.numpy() gradients = data["logits"].saved_grad.cpu()
gradients = data["logits"].saved_grad.cpu().data.numpy()
return costs, gradients return costs, gradients
...@@ -41,12 +38,12 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False): ...@@ -41,12 +38,12 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
loss = torch.sum(costs) loss = torch.sum(costs)
loss.backward() loss.backward()
costs = costs.cpu().data.numpy() costs = costs.cpu()
gradients = data["logits"].saved_grad.cpu().data.numpy() gradients = data["logits"].saved_grad.cpu()
return costs, gradients return costs, gradients
def get_data_basic(device): def get_basic_data(device):
# Example provided # Example provided
# in 6f73a2513dc784c59eec153a45f40bc528355b18 # in 6f73a2513dc784c59eec153a45f40bc528355b18
# of https://github.com/HawkAaron/warp-transducer # of https://github.com/HawkAaron/warp-transducer
...@@ -66,16 +63,12 @@ def get_data_basic(device): ...@@ -66,16 +63,12 @@ def get_data_basic(device):
], ],
] ]
], ],
dtype=torch.float, dtype=torch.float32,
device=device,
) )
targets = torch.tensor([[1, 2]], dtype=torch.int) targets = torch.tensor([[1, 2]], dtype=torch.int, device=device)
logit_lengths = torch.tensor([2], dtype=torch.int) logit_lengths = torch.tensor([2], dtype=torch.int, device=device)
target_lengths = torch.tensor([2], dtype=torch.int) target_lengths = torch.tensor([2], dtype=torch.int, device=device)
logits = logits.to(device=device)
targets = targets.to(device=device)
logit_lengths = logit_lengths.to(device=device)
target_lengths = target_lengths.to(device=device)
logits.requires_grad_(True) logits.requires_grad_(True)
...@@ -84,27 +77,32 @@ def get_data_basic(device): ...@@ -84,27 +77,32 @@ def get_data_basic(device):
def get_B1_T10_U3_D4_data( def get_B1_T10_U3_D4_data(
random=False, random=False,
dtype=np.float32, dtype=torch.float32,
nan=False, device=torch.device("cpu"),
): ):
B, T, U, D = 2, 10, 3, 4 B, T, U, D = 2, 10, 3, 4
data = {}
data["logits"] = np.random.rand(B, T, U, D).astype(dtype) logits = torch.rand(B, T, U, D, dtype=dtype, device=device)
if not random: if not random:
data["logits"].fill(0.1) logits.fill_(0.1)
if nan: logits.requires_grad_(True)
for i in range(B):
data["logits"][i][0][0][0] = np.nan def grad_hook(grad):
data["logit_lengths"] = np.array([10, 10], dtype=np.int32) logits.saved_grad = grad.clone()
data["target_lengths"] = np.array([2, 2], dtype=np.int32) logits.register_hook(grad_hook)
data["targets"] = np.array([[1, 2], [1, 2]], dtype=np.int32)
data = {}
data["logits"] = logits
data["logit_lengths"] = torch.tensor([10, 10], dtype=torch.int32, device=device)
data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device)
data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device)
data["blank"] = 0 data["blank"] = 0
return data return data
def get_numpy_data_B1_T2_U3_D5(dtype=np.float32): def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
logits = np.array( logits = torch.tensor(
[ [
0.1, 0.1,
0.6, 0.6,
...@@ -138,15 +136,22 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32): ...@@ -138,15 +136,22 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
0.1, 0.1,
], ],
dtype=dtype, dtype=dtype,
device=device,
).reshape(1, 2, 3, 5) ).reshape(1, 2, 3, 5)
targets = np.array([[1, 2]], dtype=np.int32) logits.requires_grad_(True)
logit_lengths = np.array([2], dtype=np.int32)
target_lengths = np.array([2], dtype=np.int32) def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
targets = torch.tensor([[1, 2]], dtype=torch.int32, device=device)
logit_lengths = torch.tensor([2], dtype=torch.int32, device=device)
target_lengths = torch.tensor([2], dtype=torch.int32, device=device)
blank = -1 blank = -1
ref_costs = np.array([5.09566688538], dtype=dtype) ref_costs = torch.tensor([5.09566688538], dtype=dtype)
ref_gradients = np.array( ref_gradients = torch.tensor(
[ [
0.17703132, 0.17703132,
-0.39992708, -0.39992708,
...@@ -193,10 +198,9 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32): ...@@ -193,10 +198,9 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
return data, ref_costs, ref_gradients return data, ref_costs, ref_gradients
def get_numpy_data_B2_T4_U3_D3(dtype=np.float32): def get_B2_T4_U3_D3_data(dtype=torch.float32, device=torch.device("cpu")):
# Test from D21322854 # Test from D21322854
logits = torch.tensor(
logits = np.array(
[ [
0.065357, 0.065357,
0.787530, 0.787530,
...@@ -272,17 +276,23 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32): ...@@ -272,17 +276,23 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
0.358021, 0.358021,
], ],
dtype=dtype, dtype=dtype,
device=device,
).reshape(2, 4, 3, 3) ).reshape(2, 4, 3, 3)
logits.requires_grad_(True)
targets = np.array([[1, 2], [1, 1]], dtype=np.int32) def grad_hook(grad):
logit_lengths = np.array([4, 4], dtype=np.int32) logits.saved_grad = grad.clone()
target_lengths = np.array([2, 2], dtype=np.int32) logits.register_hook(grad_hook)
targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32, device=device)
logit_lengths = torch.tensor([4, 4], dtype=torch.int32, device=device)
target_lengths = torch.tensor([2, 2], dtype=torch.int32, device=device)
blank = 0 blank = 0
ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype) ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)
ref_gradients = np.array( ref_gradients = torch.tensor(
[ [
-0.186844, -0.186844,
-0.062555, -0.062555,
...@@ -371,30 +381,45 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32): ...@@ -371,30 +381,45 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
return data, ref_costs, ref_gradients return data, ref_costs, ref_gradients
def get_numpy_random_data( def get_random_data(
max_B=8, max_T=128, max_U=32, max_D=40, blank=-1, dtype=np.float32, seed=None max_B=8,
max_T=128,
max_U=32,
max_D=40,
blank=-1,
dtype=torch.float32,
device=torch.device("cpu"),
seed=None,
): ):
if seed is not None: if seed is not None:
np.random.seed(seed=seed) torch.manual_seed(seed=seed)
if blank != -1: if blank != -1:
raise ValueError("blank != -1 is not supported yet.") raise ValueError("blank != -1 is not supported yet.")
B = np.random.randint(low=1, high=max_B) random.seed(0)
T = np.random.randint(low=5, high=max_T) B = random.randint(1, max_B - 1)
U = np.random.randint(low=5, high=max_U) T = random.randint(5, max_T - 1)
D = np.random.randint(low=2, high=max_D) U = random.randint(5, max_U - 1)
D = random.randint(2, max_D - 1)
logit_lengths = np.random.randint(low=5, high=T + 1, size=(B,), dtype=np.int32)
target_lengths = np.random.randint(low=5, high=U + 1, size=(B,), dtype=np.int32) logit_lengths = torch.randint(low=5, high=T + 1, size=(B,), dtype=torch.int32, device=device)
max_src_length = np.max(logit_lengths) target_lengths = torch.randint(low=5, high=U + 1, size=(B,), dtype=torch.int32, device=device)
max_tgt_length = np.max(target_lengths) max_src_length = torch.max(logit_lengths)
targets = np.random.randint( max_tgt_length = torch.max(target_lengths)
low=0, high=D - 1, size=(B, max_tgt_length), dtype=np.int32
targets = torch.randint(
low=0, high=D - 1, size=(B, max_tgt_length), dtype=torch.int32, device=device
) )
logits = np.random.random_sample( logits = torch.rand(
size=(B, max_src_length, max_tgt_length + 1, D) size=(B, max_src_length, max_tgt_length + 1, D),
).astype(dtype=dtype) dtype=dtype,
device=device,
).requires_grad_(True)
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
return { return {
"logits": logits, "logits": logits,
...@@ -405,44 +430,6 @@ def get_numpy_random_data( ...@@ -405,44 +430,6 @@ def get_numpy_random_data(
} }
def numpy_to_torch(data, device, requires_grad=True):
logits = torch.from_numpy(data["logits"]).to(device=device)
targets = torch.from_numpy(data["targets"]).to(device=device)
logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)
if "nbest_wers" in data:
data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
if "nbest_scores" in data:
data["nbest_scores"] = torch.from_numpy(data["nbest_scores"]).to(
device=device
)
logits = torch.autograd.Variable(logits, requires_grad=requires_grad)
logit_lengths = torch.autograd.Variable(logit_lengths)
target_lengths = torch.autograd.Variable(target_lengths)
targets = torch.autograd.Variable(targets)
if device == torch.device("cpu"):
logits = logits.cpu()
elif device == torch.device("cuda"):
logits = logits.cuda()
else:
raise ValueError("unrecognized device = {}".format(device))
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
data["logits"] = logits
data["logit_lengths"] = logit_lengths
data["target_lengths"] = target_lengths
data["targets"] = targets
return data
def skipIfNoRNNT(test_item): def skipIfNoRNNT(test_item):
try: try:
torch.ops.torchaudio.rnnt_loss torch.ops.torchaudio.rnnt_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment