"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "916b9a5e65fc176275df89aeb460897a4d450232"
Unverified commit 0ea6d10d authored by Caroline Chen, committed by GitHub
Browse files

Refactor RNNT Loss Unit Tests (#1630)

parent 56ab0368
......@@ -8,10 +8,9 @@ from torchaudio_unittest.common_utils import (
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
from parameterized import parameterized
from .utils import (
numpy_to_torch,
get_B1_T10_U3_D4_data,
get_numpy_data_B2_T4_U3_D3,
get_numpy_data_B1_T2_U3_D5
get_B2_T4_U3_D3_data,
get_B1_T2_U3_D5_data
)
from .numpy_transducer import NumpyTransducerLoss
......@@ -19,12 +18,9 @@ from .numpy_transducer import NumpyTransducerLoss
class Autograd(TestBaseMixin):
@staticmethod
def get_data(data_func, device):
data_np = data_func()
if type(data_np) == tuple:
data_np = data_np[0]
data = numpy_to_torch(
data=data_np, device=device, requires_grad=True
)
data = data_func()
if type(data) == tuple:
data = data[0]
return data
def assert_grad(
......@@ -46,8 +42,8 @@ class Autograd(TestBaseMixin):
@parameterized.expand([
(get_B1_T10_U3_D4_data, ),
(get_numpy_data_B2_T4_U3_D3, ),
(get_numpy_data_B1_T2_U3_D5, ),
(get_B2_T4_U3_D3_data, ),
(get_B1_T2_U3_D5_data, ),
])
def test_RNNTLoss_gradcheck(self, data_func):
data = self.get_data(data_func, self.device)
......@@ -63,8 +59,8 @@ class Autograd(TestBaseMixin):
@parameterized.expand([
(get_B1_T10_U3_D4_data, ),
(get_numpy_data_B2_T4_U3_D3, ),
(get_numpy_data_B1_T2_U3_D5, ),
(get_B2_T4_U3_D3_data, ),
(get_B1_T2_U3_D5_data, ),
])
def test_rnnt_loss_gradcheck(self, data_func):
data = self.get_data(data_func, self.device)
......@@ -83,8 +79,8 @@ class Autograd(TestBaseMixin):
@parameterized.expand([
(get_B1_T10_U3_D4_data, ),
(get_numpy_data_B2_T4_U3_D3, ),
(get_numpy_data_B1_T2_U3_D5, ),
(get_B2_T4_U3_D3_data, ),
(get_B1_T2_U3_D5_data, ),
])
def test_np_transducer_gradcheck(self, data_func):
data = self.get_data(data_func, self.device)
......
import numpy as np
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss
from .utils import (
compute_with_numpy_transducer,
compute_with_pytorch_transducer,
get_basic_data,
get_B1_T10_U3_D4_data,
get_data_basic,
get_numpy_data_B1_T2_U3_D5,
get_numpy_data_B2_T4_U3_D3,
get_numpy_random_data,
numpy_to_torch,
get_B1_T2_U3_D5_data,
get_B2_T4_U3_D3_data,
get_random_data,
)
......@@ -23,42 +22,30 @@ class RNNTLossTest:
costs, gradients = compute_with_pytorch_transducer(
data=data, reuse_logits_for_grads=reuse_logits_for_grads
)
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
for b in range(len(gradients)):
T = data["logit_lengths"][b]
U = data["target_lengths"][b]
for t in range(gradients.shape[1]):
for u in range(gradients.shape[2]):
np.testing.assert_allclose(
gradients[b, t, u],
ref_gradients[b, t, u],
atol=atol,
rtol=rtol,
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
)
self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
def test_basic_backward(self):
    """Smoke test: RNNTLoss on the basic example supports backward().

    NOTE(review): the diff artifact left both ``get_data_basic`` and
    ``get_basic_data`` calls; the renamed helper ``get_basic_data`` is the
    one imported by this file, so it is the one kept.
    """
    rnnt_loss = RNNTLoss()
    logits, targets, logit_lengths, target_lengths = get_basic_data(self.device)
    loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)
    loss.backward()
def test_costs_and_gradients_B1_T2_U3_D5_fp32(self):
    """Compare fp32 costs/gradients against the B=1,T=2,U=3,D=5 reference.

    The data factory already returns torch tensors on ``self.device``, so
    no numpy-to-torch conversion step is needed.
    """
    data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
        dtype=torch.float32,
        device=self.device,
    )
    self._test_costs_and_gradients(
        data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
    )
def test_costs_and_gradients_B1_T2_U3_D5_fp16(self):
data, ref_costs, ref_gradients = get_numpy_data_B1_T2_U3_D5(
dtype=np.float16
data, ref_costs, ref_gradients = get_B1_T2_U3_D5_data(
dtype=torch.float16,
device=self.device,
)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
......@@ -68,19 +55,19 @@ class RNNTLossTest:
)
def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
    """Compare fp32 costs/gradients against the B=2,T=4,U=3,D=3 reference.

    The data factory already returns torch tensors on ``self.device``, so
    no numpy-to-torch conversion step is needed.
    """
    data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
        dtype=torch.float32,
        device=self.device,
    )
    self._test_costs_and_gradients(
        data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
    )
def test_costs_and_gradients_B2_T4_U3_D3_fp16(self):
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(
dtype=np.float16
data, ref_costs, ref_gradients = get_B2_T4_U3_D3_data(
dtype=torch.float16,
device=self.device,
)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data,
ref_costs=ref_costs,
......@@ -92,8 +79,7 @@ class RNNTLossTest:
def test_costs_and_gradients_random_data_with_numpy_fp32(self):
seed = 777
for i in range(5):
data = get_numpy_random_data(dtype=np.float32, seed=(seed + i))
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
data = get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i))
ref_costs, ref_gradients = compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
......@@ -103,9 +89,8 @@ class RNNTLossTest:
for random in [False, True]:
data = get_B1_T10_U3_D4_data(
random=random,
)
data = numpy_to_torch(
data=data, device=self.device, requires_grad=True
dtype=torch.float32,
device=self.device,
)
data["fused_log_softmax"] = False
ref_costs, ref_gradients = compute_with_numpy_transducer(
......
import unittest
import numpy as np
import random
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss
......@@ -19,10 +18,8 @@ def compute_with_numpy_transducer(data):
loss = torch.sum(costs)
loss.backward()
costs = costs.cpu().data.numpy()
gradients = data["logits"].saved_grad.cpu().data.numpy()
costs = costs.cpu()
gradients = data["logits"].saved_grad.cpu()
return costs, gradients
......@@ -41,12 +38,12 @@ def compute_with_pytorch_transducer(data, reuse_logits_for_grads=False):
loss = torch.sum(costs)
loss.backward()
costs = costs.cpu().data.numpy()
gradients = data["logits"].saved_grad.cpu().data.numpy()
costs = costs.cpu()
gradients = data["logits"].saved_grad.cpu()
return costs, gradients
def get_data_basic(device):
def get_basic_data(device):
# Example provided
# in 6f73a2513dc784c59eec153a45f40bc528355b18
# of https://github.com/HawkAaron/warp-transducer
......@@ -66,16 +63,12 @@ def get_data_basic(device):
],
]
],
dtype=torch.float,
dtype=torch.float32,
device=device,
)
targets = torch.tensor([[1, 2]], dtype=torch.int)
logit_lengths = torch.tensor([2], dtype=torch.int)
target_lengths = torch.tensor([2], dtype=torch.int)
logits = logits.to(device=device)
targets = targets.to(device=device)
logit_lengths = logit_lengths.to(device=device)
target_lengths = target_lengths.to(device=device)
targets = torch.tensor([[1, 2]], dtype=torch.int, device=device)
logit_lengths = torch.tensor([2], dtype=torch.int, device=device)
target_lengths = torch.tensor([2], dtype=torch.int, device=device)
logits.requires_grad_(True)
......@@ -84,27 +77,32 @@ def get_data_basic(device):
def get_B1_T10_U3_D4_data(
    random=False,
    nan=False,
    dtype=torch.float32,
    device=torch.device("cpu"),
):
    """Build an RNNT-loss test batch with T=10, U=3, D=4 (batch size 2).

    Args:
        random (bool): if True keep random logits; otherwise fill with 0.1.
        nan (bool): if True poison the first logit of each batch item
            with NaN.
        dtype (torch.dtype): dtype of the logits tensor.
        device (torch.device): device for all returned tensors.

    Returns:
        dict: with keys "logits" (requires grad; a backward hook stores the
        gradient in ``logits.saved_grad``), "logit_lengths",
        "target_lengths", "targets", and "blank".
    """
    B, T, U, D = 2, 10, 3, 4

    logits = torch.rand(B, T, U, D, dtype=dtype, device=device)
    if not random:
        logits.fill_(0.1)
    if nan:
        # Poison before enabling grad so the in-place write is allowed
        # on what will become a leaf tensor requiring grad.
        for i in range(B):
            logits[i][0][0][0] = float("nan")
    logits.requires_grad_(True)

    def grad_hook(grad):
        logits.saved_grad = grad.clone()

    logits.register_hook(grad_hook)

    data = {}
    data["logits"] = logits
    data["logit_lengths"] = torch.tensor([10, 10], dtype=torch.int32, device=device)
    data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device)
    data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device)
    data["blank"] = 0

    return data
def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
logits = np.array(
def get_B1_T2_U3_D5_data(dtype=torch.float32, device=torch.device("cpu")):
logits = torch.tensor(
[
0.1,
0.6,
......@@ -138,15 +136,22 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
0.1,
],
dtype=dtype,
device=device,
).reshape(1, 2, 3, 5)
targets = np.array([[1, 2]], dtype=np.int32)
logit_lengths = np.array([2], dtype=np.int32)
target_lengths = np.array([2], dtype=np.int32)
logits.requires_grad_(True)
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
targets = torch.tensor([[1, 2]], dtype=torch.int32, device=device)
logit_lengths = torch.tensor([2], dtype=torch.int32, device=device)
target_lengths = torch.tensor([2], dtype=torch.int32, device=device)
blank = -1
ref_costs = np.array([5.09566688538], dtype=dtype)
ref_gradients = np.array(
ref_costs = torch.tensor([5.09566688538], dtype=dtype)
ref_gradients = torch.tensor(
[
0.17703132,
-0.39992708,
......@@ -193,10 +198,9 @@ def get_numpy_data_B1_T2_U3_D5(dtype=np.float32):
return data, ref_costs, ref_gradients
def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
def get_B2_T4_U3_D3_data(dtype=torch.float32, device=torch.device("cpu")):
# Test from D21322854
logits = np.array(
logits = torch.tensor(
[
0.065357,
0.787530,
......@@ -272,17 +276,23 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
0.358021,
],
dtype=dtype,
device=device,
).reshape(2, 4, 3, 3)
logits.requires_grad_(True)
targets = np.array([[1, 2], [1, 1]], dtype=np.int32)
logit_lengths = np.array([4, 4], dtype=np.int32)
target_lengths = np.array([2, 2], dtype=np.int32)
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32, device=device)
logit_lengths = torch.tensor([4, 4], dtype=torch.int32, device=device)
target_lengths = torch.tensor([2, 2], dtype=torch.int32, device=device)
blank = 0
ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)
ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype)
ref_gradients = np.array(
ref_gradients = torch.tensor(
[
-0.186844,
-0.062555,
......@@ -371,30 +381,45 @@ def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
return data, ref_costs, ref_gradients
def get_numpy_random_data(
max_B=8, max_T=128, max_U=32, max_D=40, blank=-1, dtype=np.float32, seed=None
def get_random_data(
max_B=8,
max_T=128,
max_U=32,
max_D=40,
blank=-1,
dtype=torch.float32,
device=torch.device("cpu"),
seed=None,
):
if seed is not None:
np.random.seed(seed=seed)
torch.manual_seed(seed=seed)
if blank != -1:
raise ValueError("blank != -1 is not supported yet.")
B = np.random.randint(low=1, high=max_B)
T = np.random.randint(low=5, high=max_T)
U = np.random.randint(low=5, high=max_U)
D = np.random.randint(low=2, high=max_D)
logit_lengths = np.random.randint(low=5, high=T + 1, size=(B,), dtype=np.int32)
target_lengths = np.random.randint(low=5, high=U + 1, size=(B,), dtype=np.int32)
max_src_length = np.max(logit_lengths)
max_tgt_length = np.max(target_lengths)
targets = np.random.randint(
low=0, high=D - 1, size=(B, max_tgt_length), dtype=np.int32
random.seed(0)
B = random.randint(1, max_B - 1)
T = random.randint(5, max_T - 1)
U = random.randint(5, max_U - 1)
D = random.randint(2, max_D - 1)
logit_lengths = torch.randint(low=5, high=T + 1, size=(B,), dtype=torch.int32, device=device)
target_lengths = torch.randint(low=5, high=U + 1, size=(B,), dtype=torch.int32, device=device)
max_src_length = torch.max(logit_lengths)
max_tgt_length = torch.max(target_lengths)
targets = torch.randint(
low=0, high=D - 1, size=(B, max_tgt_length), dtype=torch.int32, device=device
)
logits = np.random.random_sample(
size=(B, max_src_length, max_tgt_length + 1, D)
).astype(dtype=dtype)
logits = torch.rand(
size=(B, max_src_length, max_tgt_length + 1, D),
dtype=dtype,
device=device,
).requires_grad_(True)
def grad_hook(grad):
logits.saved_grad = grad.clone()
logits.register_hook(grad_hook)
return {
"logits": logits,
......@@ -405,44 +430,6 @@ def get_numpy_random_data(
}
def numpy_to_torch(data, device, requires_grad=True):
    """Convert the numpy arrays in ``data`` to torch tensors on ``device``.

    A backward hook is registered on the logits so that, after
    ``backward()``, the gradient is available as ``logits.saved_grad``.

    Args:
        data (dict): holds numpy arrays under "logits", "targets",
            "logit_lengths", "target_lengths" (and optionally
            "nbest_wers" / "nbest_scores").
        device (torch.device): target device for all tensors.
        requires_grad (bool): whether the logits should track gradients.

    Returns:
        dict: ``data``, updated in place with torch tensors.
    """
    logits = torch.from_numpy(data["logits"]).to(device=device)
    targets = torch.from_numpy(data["targets"]).to(device=device)
    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)

    if "nbest_wers" in data:
        data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
    if "nbest_scores" in data:
        data["nbest_scores"] = torch.from_numpy(data["nbest_scores"]).to(
            device=device
        )

    # requires_grad_ replaces the deprecated torch.autograd.Variable wrapper.
    # The explicit cpu()/cuda() branch is dropped: .to(device) above already
    # placed the tensors, and the old equality check rejected valid devices
    # such as torch.device("cuda:0") with a spurious ValueError.
    logits.requires_grad_(requires_grad)

    if requires_grad:
        # Hooks can only be registered on tensors that require grad;
        # guarding avoids a RuntimeError when requires_grad=False.
        def grad_hook(grad):
            logits.saved_grad = grad.clone()

        logits.register_hook(grad_hook)

    data["logits"] = logits
    data["logit_lengths"] = logit_lengths
    data["target_lengths"] = target_lengths
    data["targets"] = targets

    return data
def skipIfNoRNNT(test_item):
try:
torch.ops.torchaudio.rnnt_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment