Unverified commit a283cddf authored by Daniel Povey, committed by GitHub

Merge pull request #4 from pkufool/fast_rnnt

Sync with k2 rnnt_loss
parents b5828e2b 182fe8de
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (authors: Daniel Povey,
# Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# To run this single test, use
#
# ctest --verbose -R rnnt_loss_test_py
import unittest
import fast_rnnt
import random
import torch
class TestRnntLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.devices = [torch.device("cpu")]
if torch.cuda.is_available() and fast_rnnt.with_cuda():
cls.devices.append(torch.device("cuda", 0))
if torch.cuda.device_count() > 1:
torch.cuda.set_device(1)
cls.devices.append(torch.device("cuda", 1))
try:
import torchaudio
import torchaudio.functional
if hasattr(torchaudio.functional, "rnnt_loss"):
cls.has_torch_rnnt_loss = True
else:
cls.has_torch_rnnt_loss = False
print(
f"Current torchaudio version: {torchaudio.__version__}\n"
"Skipping the tests of comparing rnnt loss with torch "
"one, to enable these tests please install a "
"version >= 0.10.0"
)
except ImportError as e:
cls.has_torch_rnnt_loss = False
print(
f"Import torchaudio error, error message: {e}\n"
"Skipping the tests of comparing rnnt loss with torch "
"one, to enable these tests, please install torchaudio "
"with version >= 0.10.0"
)
def test_rnnt_loss_basic(self):
B = 1
S = 3
T = 4
# C = 3
for device in self.devices:
# lm: [B][S+1][C]
lm = torch.tensor(
[[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
dtype=torch.float,
device=device,
)
# am: [B][T][C]
am = torch.tensor(
[[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
dtype=torch.float,
device=device,
)
termination_symbol = 2
symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)
px, py = fast_rnnt.get_rnnt_logprobs(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
)
assert px.shape == (B, S, T + 1)
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(px=px, py=py, boundary=None)
if device == torch.device("cpu"):
expected = -m
assert torch.allclose(-m, expected.to(device))
# test rnnt_loss_simple
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
# test rnnt_loss_smoothed
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
# test rnnt_loss
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss:
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
logits=probs,
targets=symbols.int(),
logit_lengths=torch.tensor(
[T] * B, dtype=torch.int32, device=device
),
target_lengths=torch.tensor(
[S] * B, dtype=torch.int32, device=device
),
blank=termination_symbol,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S + 1, 1, device=device)
am += torch.randn(B, T, 1, device=device)
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=None,
reduction="none",
)
assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_random(self):
B = 5
S = 20
T = 300
C = 100
frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)
am_ = torch.randn((B, T, C), dtype=torch.float32)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols_ = torch.randint(0, C - 1, (B, S))
termination_symbol = C - 1
boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
# am: [B][T][C]
am = am_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)
px, py = fast_rnnt.get_rnnt_logprobs(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert px.shape == ((B, S, T) if modified else (B, S, T + 1))
assert py.shape == (B, S + 1, T)
assert symbols.shape == (B, S)
m = fast_rnnt.mutual_information_recursion(
px=px, py=py, boundary=boundary
)
if device == torch.device("cpu"):
expected = -torch.mean(m)
assert torch.allclose(-torch.mean(m), expected.to(device))
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
# compare with torchaudio rnnt_loss
if self.has_torch_rnnt_loss and not modified:
import torchaudio.functional
m = torchaudio.functional.rnnt_loss(
logits=probs,
targets=symbols.int(),
logit_lengths=boundary[:, 3].int(),
target_lengths=boundary[:, 2].int(),
blank=termination_symbol,
)
assert torch.allclose(m, expected.to(device))
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S + 1, 1, device=device)
am += torch.randn(B, T, 1, device=device)
m = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
probs = am.unsqueeze(2) + lm.unsqueeze(1)
m = fast_rnnt.rnnt_loss(
logits=probs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.0,
boundary=boundary,
modified=modified,
)
assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_gradient(self):
if self.has_torch_rnnt_loss:
import torchaudio.functional
B = 5
S = 20
T = 300
C = 100
frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)
am_ = torch.randn((B, T, C), dtype=torch.float32)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols_ = torch.randint(0, C - 1, (B, S))
termination_symbol = C - 1
boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for device in self.devices:
# lm: [B][S+1][C]
lm = lm_.to(device)
# am: [B][T][C]
am = am_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)
logprobs = am.unsqueeze(2) + lm.unsqueeze(1)
logprobs.requires_grad_()
k2_loss = fast_rnnt.rnnt_loss(
logits=logprobs,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
)
k2_grad = torch.autograd.grad(k2_loss, logprobs)
k2_grad = k2_grad[0]
logprobs2 = logprobs.detach().clone().float()
logprobs2.requires_grad_()
torch_loss = torchaudio.functional.rnnt_loss(
logits=logprobs2,
targets=symbols.int(),
logit_lengths=boundary[:, 3].int(),
target_lengths=boundary[:, 2].int(),
blank=termination_symbol,
)
torch_grad = torch.autograd.grad(torch_loss, logprobs2)
torch_grad = torch_grad[0]
assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2)
assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2)
def test_rnnt_loss_smoothed(self):
B = 1
S = 3
T = 4
# C = 3
for device in self.devices:
# lm: [B][S+1][C]
lm = torch.tensor(
[[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
dtype=torch.float,
device=device,
)
# am: [B][T][C]
am = torch.tensor(
[[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
dtype=torch.float,
device=device,
)
termination_symbol = 2
symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.333,
boundary=None,
)
if device == torch.device("cpu"):
expected = m
assert torch.allclose(m, expected.to(device))
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S + 1, 1, device=device)
am += torch.randn(B, T, 1, device=device)
m = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.0,
am_only_scale=0.333,
boundary=None,
)
assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_pruned(self):
B = 4
T = 300
S = 50
C = 10
frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)
am_ = torch.randn((B, T, C), dtype=torch.float64)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
symbols_ = torch.randint(0, C - 1, (B, S))
terminal_symbol = C - 1
boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames
for modified in [True, False]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
lm = lm_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)
t_am = am.unsqueeze(2).float()
t_lm = lm.unsqueeze(1).float()
t_prob = t_am + t_lm
# nonlinear transform
t_prob = torch.sigmoid(t_prob)
k2_loss = fast_rnnt.rnnt_loss(
logits=t_prob,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
)
print(
f"unpruned rnnt loss with modified {modified} : {k2_loss}"
)
# pruning
k2_simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
return_grad=True,
reduction="none",
)
for r in range(2, 50, 5):
ranges = fast_rnnt.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=r,
)
# (B, T, r, C)
am_p, lm_p = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
t_prob_p = am_p + lm_p
# nonlinear transform
t_prob_p = torch.sigmoid(t_prob_p)
pruned_loss = fast_rnnt.rnnt_loss_pruned(
logits=t_prob_p,
symbols=symbols,
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
modified=modified,
reduction="none",
)
print(f"pruning loss with range {r} : {pruned_loss}")
if __name__ == "__main__":
unittest.main()
[build-system]
requires = ['setuptools>=38.2.5', 'wheel', 'torch>=1.5', 'ninja']
build-backend = "setuptools.build_meta"
\ No newline at end of file
#!/usr/bin/env python
#!/usr/bin/env python3
#
# Copyright (c) 2022 Xiaomi Corporation (author: Wei Kang)
import glob
import os
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
import re
import shutil
import sys
import setuptools
from setuptools.command.build_ext import build_ext
cur_dir = os.path.dirname(os.path.abspath(__file__))
def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
kwargs["language"] = "c++"
sources = []
return setuptools.Extension(name, sources, *args, **kwargs)
with open('requirements.txt') as f:
requirements = f.read().splitlines()
class BuildExtension(build_ext):
def build_extension(self, ext: setuptools.extension.Extension):
# build/temp.linux-x86_64-3.8
build_dir = self.build_temp
os.makedirs(build_dir, exist_ok=True)
# build/lib.linux-x86_64-3.8
os.makedirs(self.build_lib, exist_ok=True)
long_description = """
This package implements an efficient parallel algorithm for the computation of
mutual information between sequences with differentiable bindings to PyTorch.
ft_dir = os.path.dirname(os.path.abspath(__file__))
cmake_args = os.environ.get("FT_CMAKE_ARGS", "")
make_args = os.environ.get("FT_MAKE_ARGS", "")
system_make_args = os.environ.get("MAKEFLAGS", "")
Find more details and the most up-to-date information on the project webpage:
[TODO]
"""
if cmake_args == "":
cmake_args = "-DCMAKE_BUILD_TYPE=Release"
if make_args == "" and system_make_args == "":
print("For fast compilation, run:")
print('export FT_MAKE_ARGS="-j"; python setup.py install')
def configure_extensions():
out = [
CppExtension(
'torch_mutual_information_cpu',
[
os.path.join('torch_mutual_information', 'mutual_information_cpu.cpp'),
],
)
]
try:
out.append(
CUDAExtension(
'torch_mutual_information_cuda',
[
os.path.join('torch_mutual_information', 'mutual_information_cuda.cpp'),
os.path.join('torch_mutual_information', 'mutual_information_cuda_kernel.cu'),
],
if "PYTHON_EXECUTABLE" not in cmake_args:
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"
build_cmd = f"""
cd {self.build_temp}
cmake {cmake_args} {ft_dir}
make {make_args} _fast_rnnt
"""
print(f"build command is:\n{build_cmd}")
ret = os.system(build_cmd)
if ret != 0:
raise Exception(
"\nBuild fast_rnnt failed. Please check the error "
"message.\n"
"You can ask for help by creating an issue on GitHub.\n"
"\nClick:\n"
"\thttps://github.com/danpovey/fast_rnnt/issues/new\n" # noqa
)
)
except Exception as e:
print(f'Failed to build CUDA extension, this part of the package will not work. Reason: {str(e)}')
return out
setup(
name='torch_mutual_information',
version='1.0.2',
description='Mutual information between sequences of vectors',
long_description=long_description,
long_description_content_type='text/markdown',
install_requires=requirements,
python_requires='>=3.6',
packages=find_packages(),
author='Dan Povey',
license='BSD',
ext_modules=configure_extensions(),
cmdclass={
'build_ext': BuildExtension
lib_so = glob.glob(f"{build_dir}/lib/*.so*")
for so in lib_so:
print(f"Copying {so} to {self.build_lib}/")
shutil.copy(f"{so}", f"{self.build_lib}/")
# macos
lib_so = glob.glob(f"{build_dir}/lib/*.dylib*")
for so in lib_so:
print(f"Copying {so} to {self.build_lib}/")
shutil.copy(f"{so}", f"{self.build_lib}/")
def read_long_description():
with open("README.md", encoding="utf8") as f:
readme = f.read()
return readme
def get_package_version():
with open("CMakeLists.txt") as f:
content = f.read()
latest_version = re.search(r"set\(FT_VERSION (.*)\)", content).group(1)
latest_version = latest_version.strip('"')
return latest_version
def get_requirements():
with open("requirements.txt", encoding="utf8") as f:
requirements = f.read().splitlines()
return requirements
package_name = "fast_rnnt"
with open(
"fast_rnnt/python/fast_rnnt/__init__.py", "a"
) as f:
f.write(f"__version__ = '{get_package_version()}'\n")
setuptools.setup(
name=package_name,
version=get_package_version(),
author="Dan Povey",
author_email="dpovey@gmail.com",
package_dir={
package_name: "fast_rnnt/python/fast_rnnt",
},
keywords=[
'pytorch', 'sequence', 'mutual', 'information'
packages=[package_name],
url="https://github.com/danpovey/fast_rnnt",
description="Fast and memory-efficient RNN-T loss.",
long_description=read_long_description(),
long_description_content_type="text/markdown",
install_requires=get_requirements(),
ext_modules=[cmake_extension("_fast_rnnt")],
cmdclass={"build_ext": BuildExtension},
zip_safe=False,
classifiers=[
"Programming Language :: C++",
"Programming Language :: Python",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
license="Apache licensed, as found in the LICENSE file",
)
import os
import random
import time
import unittest
import torch
from tqdm import tqdm
from torch_discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right
def get_grad(param, out):
out.sum().backward()
grad = param.grad.clone()
del param.grad
return grad
def discounted_cumsum_left_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in range(input.shape[1]):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.append(last_col)
out = torch.cat(out, dim=1)
return out
def discounted_cumsum_right_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in reversed(range(input.shape[1])):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.insert(0, last_col)
out = torch.cat(out, dim=1)
return out
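# A worked toy example of the two recursions above (illustration only, not part
# of the test suite): with gamma = 0.5 and input [[1., 2., 3.]],
#   left:  [1, 2 + 0.5*1, 3 + 0.5*2.5]  -> [[1.0, 2.5, 4.25]]
#   right: [1 + 0.5*3.5, 2 + 0.5*3, 3]  -> [[2.75, 3.5, 3.0]]
def _discounted_cumsum_gold_example():
    x = torch.tensor([[1.0, 2.0, 3.0]])
    assert torch.allclose(discounted_cumsum_left_gold(x, 0.5),
                          torch.tensor([[1.0, 2.5, 4.25]]))
    assert torch.allclose(discounted_cumsum_right_gold(x, 0.5),
                          torch.tensor([[2.75, 3.5, 3.0]]))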
def discounted_cumsum_lib(x, gamma, dir):
return {
'left': discounted_cumsum_left,
'right': discounted_cumsum_right,
}[dir](x, gamma)
def discounted_cumsum_gold(x, gamma, dir):
return {
'left': discounted_cumsum_left_gold,
'right': discounted_cumsum_right_gold,
}[dir](x, gamma)
def compute_linf(batchsz, veclen, dir, gamma=0.99, dtype=torch.float32, cuda=False, data='randn', tol=1e-3, seed=2021):
torch.manual_seed(seed)
if data == 'randn':
x = torch.randn((batchsz, veclen), dtype=dtype)
elif data == 'ones':
x = torch.ones((batchsz, veclen), dtype=dtype)
else:
raise ValueError('Invalid data generation identifier')
if cuda:
x = x.cuda()
x = torch.nn.Parameter(x)
out_gold = discounted_cumsum_gold(x, gamma, dir)
grad_gold = get_grad(x, out_gold)
out_lib = discounted_cumsum_lib(x, gamma, dir)
grad_lib = get_grad(x, out_lib)
out_linf = (out_lib - out_gold).abs().max().item()
grad_linf = (grad_lib - grad_gold).abs().max().item()
if out_linf >= tol or grad_linf >= tol:
print(f'x={x}\nout_gold={out_gold}\nout_lib={out_lib}\ngrad_gold={grad_gold}\ngrad_lib={grad_lib}\n')
return out_linf, grad_linf
class TestDiscountedCumSum(unittest.TestCase):
def test_validity(self):
print('Testing validity...')
is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
for cuda in (True, False):
if cuda and not is_cuda:
print('Skipping validity CUDA tests')
continue
rng = random.Random(2021)
with tqdm(total=2*2*2*17) as pbar:
for data in ('ones', 'randn'):
for dtype in (torch.float32, torch.float64):
for i in range(2):
batchsz = 8 ** i
for j in range(17):
veclen = max(1, 2 ** j + rng.randint(-1, 1))
gamma = rng.random()
seed = rng.randint(0, 2 ** 16)
dir = rng.choice(['left', 'right'])
tol = 2e-3
out_linf, grad_linf = compute_linf(
batchsz, veclen, dir, gamma, dtype, cuda, data, tol, seed
)
msg = f'Validity test failed with batchsz={batchsz}, veclen={veclen}, dir={dir}, ' \
f'gamma={gamma}, dtype={dtype}, cuda={cuda}, data={data}, seed={seed}, ' \
f'out_linf={out_linf}, grad_linf={grad_linf}'
self.assertLess(out_linf, tol, msg)
self.assertLess(grad_linf, tol, msg)
pbar.update(1)
def test_precision(self):
print('Testing precision...')
is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
if not is_cuda:
print('Skipping precision tests')
return
batchsz = 1
veclen = 10000
gamma = 0.99
dir = 'right'
for data in ('ones', 'randn'):
if data == 'ones':
precision_factor = 2.0
else:
precision_factor = 1.1
torch.manual_seed(2021)
if data == 'randn':
x_32 = torch.randn((batchsz, veclen), dtype=torch.float32)
elif data == 'ones':
x_32 = torch.ones((batchsz, veclen), dtype=torch.float32)
else:
raise ValueError('Invalid data generation identifier')
x_32 = x_32.cuda()
x_64 = x_32.double()
gold_64 = discounted_cumsum_gold(x_64, gamma, dir)
gold_32 = discounted_cumsum_gold(x_32, gamma, dir).double()
lib_32 = discounted_cumsum_lib(x_32, gamma, dir).double()
err_32_gold = (gold_32 - gold_64).abs().max().item()
err_32_lib = (lib_32 - gold_64).abs().max().item()
msg = f'Precision improvement test failed with data={data}, ' \
f'err_32_gold={err_32_gold}, err_32_lib={err_32_lib}'
self.assertLess(precision_factor * err_32_lib, err_32_gold, msg)
print(f'data={data}\nerr_32_gold={err_32_gold:10.8f}\nerr_32_lib ={err_32_lib:10.8f}')
def test_speed(self):
print('Testing speed...')
is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
NUM_RUNS = 30
NUM_RUNS_GOLD = 6
if not is_cuda:
print('Skipping speed tests')
return
gamma = 0.99
x_32 = torch.randn((1, 100000), dtype=torch.float32)
x_32 += torch.ones_like(x_32)
x_32_gpu = x_32.cuda()
timer = time.clock_gettime(time.CLOCK_MONOTONIC)
for _ in tqdm(range(NUM_RUNS_GOLD), desc='gold', leave=True):
discounted_cumsum_right_gold(x_32, gamma)
dur_gold = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
dur_gold = dur_gold * NUM_RUNS / NUM_RUNS_GOLD
timer = time.clock_gettime(time.CLOCK_MONOTONIC)
for _ in tqdm(range(NUM_RUNS), desc='lib_cpu', leave=True):
discounted_cumsum_right(x_32, gamma)
dur_lib_cpu = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
timer = time.clock_gettime(time.CLOCK_MONOTONIC)
for _ in tqdm(range(NUM_RUNS), desc='lib_cuda', leave=True):
discounted_cumsum_right(x_32_gpu, gamma)
dur_lib_cuda = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
print(f'dur_gold: {dur_gold:7.4f} sec')
print(f'dur_lib_cpu: {dur_lib_cpu:7.4f} sec')
print(f'dur_lib_cuda: {dur_lib_cuda:7.4f} sec')
print(f'speedup gold -> lib_cpu: {dur_gold / dur_lib_cpu:5.2f}')
print(f'speedup gold -> lib_cuda: {dur_gold / dur_lib_cuda:5.2f}')
print(f'speedup lib_cpu -> lib_cuda: {dur_lib_cpu / dur_lib_cuda:5.2f}')
if __name__ == '__main__':
unittest.main()
from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
from .rnnt import get_rnnt_logprobs, rnnt_loss_simple, rnnt_loss_aux
import os
import torch
from torch import Tensor
from typing import Tuple, Optional, Sequence
from torch.utils.cpp_extension import load
VERBOSE = False
def _resolve(name):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
try:
import torch_mutual_information_cpu
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_mutual_information_cpu')
torch_mutual_information_cpu = load(
name='torch_mutual_information_cpu',
sources=[
_resolve('mutual_information_cpu.cpp'),
],
verbose=VERBOSE,
)
try:
import torch_mutual_information_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_mutual_information_cuda')
torch_mutual_information_cuda = None
if torch.cuda.is_available():
torch_mutual_information_cuda = load(
name='torch_mutual_information_cuda',
sources=[
_resolve('mutual_information_cuda.cpp'),
_resolve('mutual_information_cuda_kernel.cu'),
],
verbose=VERBOSE,
)
def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda(
px, py, boundary, p)
else:
return torch_mutual_information_cpu.mutual_information_cpu(
px, py, boundary, p)
def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundary: torch.Tensor, p: torch.Tensor,
ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
overwrite_ans_grad = True
if overwrite_ans_grad:
ans_grad_copy = ans_grad.clone()
ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excesssive roundoff in mutual information backward "
f"recursion: {ans_grad} vs. {ans_grad_copy}");
return ans
else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
px, py, boundary, p, ans_grad))
class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Tensor]) -> torch.Tensor:
(B, S, T1) = px.shape
T = T1 - 1;
assert py.shape == (B, S + 1, T)
if boundary is not None:
assert boundary.shape == (B, 4)
else:
boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)
# p is a tensor of shape (B, S + 1, T + 1), where p[b][s][t] is the
# mutual information of the pair of subsequences of x and y that are of
# length s and t respectively. p[0][0] will be 0.0 and p[S][T] is
# the mutual information of the entire pair of sequences, i.e. of lengths
# S and T respectively.
# It is computed as follows (in C++ and CUDA):
# p[b,0,0] = 0.0
# p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
# p[b,s,t-1] + py[b,s,t-1])
# if s > 0 or t > 0,
# treating values with any -1 index as -infinity.
# .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)
ans = _mutual_information_forward_dispatcher(px, py, boundary, p)
# print(f"p = {p}, boundary = {boundary}, psum={p.sum()}")
if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundary, p)
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundary, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
return (px_grad, py_grad, None)
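# The following is a naive pure-Python sketch of the recursion documented in
# forward() above, given for illustration only; it is a hypothetical helper,
# not part of the package, and far too slow for real use (the actual
# computation is done by the C++/CUDA extensions). It assumes the default
# boundary [0, 0, S, T].
def _mutual_information_recursion_reference(px: Tensor, py: Tensor) -> Tensor:
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype, device=px.device)
    p = torch.full((B, S + 1, T + 1), float('-inf'),
                   dtype=px.dtype, device=px.device)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term_x = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_y = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_x, term_y)
    return p[:, S, T]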
def mutual_information_recursion(px, py, boundary=None):
"""A recursion that is useful in computing mutual information between two sequences of
real vectors, but may be useful more generally in sequence-to-sequence tasks where
monotonic alignment between pairs of sequences is desired. The definitions of
the arguments are definitions that would be used when computing this type of
mutual information, but you can also view them as arbitrary quantities and just
make use of the formula computed by this function.
Args:
px: A torch.Tensor of some floating point type, with shape [B][S][T+1],
where B is the batch size, S is the length of the 'x' sequence
(including representations of EOS symbols but not BOS symbols), and T is the
length of the 'y' sequence (including representations of
EOS symbols but not BOS symbols). In the mutual information application,
px[b][s][t] would represent the following log odds ratio; ignoring
the b index on the right to make the notation more compact,
px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ]
This expression also implicitly includes the log-probability of
choosing to generate an x value as opposed to a y value. In
practice it might be computed as a + b, where a is the log
probability of choosing to extend the sequence of length (s,t)
with an x as opposed to a y value; and b might in practice be
of the form:
log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t'))
where N is the number of terms that the sum over t' included, which
might include some or all of the other sequences as well as this one.
Note: we don't require px and py to be contiguous, but the
code assumes for optimization purposes that the T axis has
stride 1.
py: A torch.Tensor of the same dtype as px, with shape [B][S+1][T],
representing
py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ]
This function does not treat x and y differently; the only difference
is that for optimization purposes we assume the last axis (the t axis)
has stride of 1; this is true if px and py are contiguous.
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S and
0 <= t_begin <= t_end <= T (this implies that empty sequences are allowed). If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
Returns a torch.Tensor of shape [B], containing the log of the mutual
information between the b'th pair of sequences. This is defined by
the following recursion on p[b,s,t] (where p is of shape [B,S+1,T+1]),
representing a mutual information between sub-sequences of lengths s and t:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
(if s > 0 or t > 0)
where we handle edge cases by treating quantities with negative indexes
as -infinity. The extension to cases where the boundaries are specified
should be obvious; it just works on shorter sequences with offsets into
px and py.
"""
assert px.ndim == 3
B, S, T1 = px.shape
T = T1 - 1
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for [ s_begin, t_begin, s_end, t_end ] in boundary.to('cpu').tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.stride()[-1] == 1
assert py.stride()[-1] == 1
return MutualInformationRecursionFunction.apply(px, py, boundary)
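# Minimal usage sketch (illustration only; this hypothetical helper is not part
# of the package), following the shapes documented above.
def _mutual_information_recursion_example():
    B, S, T = 2, 3, 4
    px = torch.randn(B, S, T + 1)
    py = torch.randn(B, S + 1, T)
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)
    m = mutual_information_recursion(px, py, boundary)  # shape (B,)
    assert m.shape == (B,)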
def _inner(a: Tensor, b: Tensor) -> Tensor:
"""
Does inner product on the last dimension, with expected broadcasting, i.e. equivalent to
(a * b).sum(dim=-1)
without creating a large temporary.
"""
assert a.shape[-1] == b.shape[-1]  # the last dim (call it K) must match
a = a.unsqueeze(-2) # (..., 1, K)
b = b.unsqueeze(-1) # (..., K, 1)
c = torch.matmul(a, b) # (..., 1, 1)
return c.squeeze(-1).squeeze(-1)
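# Quick sanity sketch for _inner() (illustration only, not part of the package):
# it should agree with (a * b).sum(dim=-1) up to floating-point roundoff.
def _inner_example():
    a = torch.randn(2, 3, 4)
    b = torch.randn(2, 3, 4)
    assert torch.allclose(_inner(a, b), (a * b).sum(dim=-1), atol=1.0e-06)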
def joint_mutual_information_recursion(px: Sequence[Tensor], py: Sequence[Tensor],
boundary: Optional[Tensor] = None) -> Sequence[Tensor]:
"""A recursion that is useful for modifications of RNN-T and similar loss functions,
where the recursion probabilities have a number of terms and you want them reported
separately. See mutual_information_recursion() for more documentation of the
basic aspects of this.
Args:
px: a sequence of Tensors, each of the same shape [B][S][T+1]
py: a sequence of Tensors, each of the same shape [B][S+1][T]; the sequence must be
the same length as px.
boundary: optionally, a LongTensor of shape [B][4] containing rows
[s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end <= S and
0 <= t_begin <= t_end <= T, defaulting to [0, 0, S, T].
These are the beginning and
one-past-the-last positions in the x and y sequences
respectively, and can be used if not all sequences are of the same length.
Returns:
a Tensor of shape (len(px), B),
whose sum over dim 0 is the total log-prob of the recursion mentioned below, per sequence.
The first element of the sequence of length len(px) is "special", in that it has an offset term
reflecting the difference between sum-of-log and log-of-sum; for more interpretable
loss values, the "main" part of your loss function should be first.
The recursion below applies if boundary == None, when it defaults
to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px and py:
p = tensor of shape (B, S+1, T+1), containing -infinity
p[b,0,0] = 0.0
# do the following in loop over s and t:
p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t],
p[b,s,t-1] + py_sum[b,s,t-1])
(if s > 0 or t > 0)
return p[:,S,T]
This function lets you implement the above recursion efficiently, except
that it gives you a breakdown of the contribution from all the elements of
px and py separately. As noted above, the first element of the
sequence is "special".
"""
N = len(px)
assert len(py) == N and N > 0
B, S, T1 = px[0].shape
T = T1 - 1
assert py[0].shape == (B, S + 1, T)
assert px[0].dtype == py[0].dtype
px_cat = torch.stack(px, dim=0) # (N, B, S, T+1)
py_cat = torch.stack(py, dim=0) # (N, B, S+1, T)
px_tot = px_cat.sum(dim=0) # (B, S, T+1)
py_tot = py_cat.sum(dim=0) # (B, S+1, T)
if boundary is not None:
assert boundary.dtype == torch.int64
assert boundary.shape == (B, 4)
for [ s_begin, t_begin, s_end, t_end ] in boundary.to('cpu').tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
else:
boundary = torch.zeros(0, 0, dtype=torch.int64, device=px_tot.device)
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.stride()[-1] == 1 and px_tot.ndim == 3
assert py_tot.stride()[-1] == 1 and py_tot.ndim == 3
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
# note, tot_probs is without grad.
tot_probs = _mutual_information_forward_dispatcher(px_tot, py_tot, boundary, p)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px_tot, py_tot,
boundary, p, ans_grad)
px_grad, py_grad = px_grad.reshape(1, B, -1), py_grad.reshape(1, B, -1)
px_cat, py_cat = px_cat.reshape(N, B, -1), py_cat.reshape(N, B, -1)
x_prods = _inner(px_grad, px_cat) # (N, B)
y_prods = _inner(py_grad, py_cat) # (N, B)
# If all the occupation counts were exactly 1.0 (i.e. no partial counts),
# "prods" should be equal to "tot_probs"; however, in general, "tot_probs"
# will be more positive due to the difference between log-of-sum and
# sum-of-log
prods = x_prods + y_prods # (N, B)
with torch.no_grad():
offset = tot_probs - prods.sum(dim=0) # (B,)
prods[0] += offset
return prods # (N, B)
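# Usage sketch (illustration only; this hypothetical helper is not part of the
# package): splitting px and py into two equal halves and summing the joint
# recursion over dim 0 should match the plain recursion, as the unit tests
# also check.
def _joint_mutual_information_example():
    B, S, T = 2, 3, 4
    px = torch.randn(B, S, T + 1)
    py = torch.randn(B, S + 1, T)
    m = mutual_information_recursion(px, py)
    m_joint = joint_mutual_information_recursion(
        (px * 0.5, px * 0.5), (py * 0.5, py * 0.5)
    ).sum(dim=0)
    assert torch.allclose(m, m_joint, atol=1.0e-05)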
#include <math.h> // for log1p, log1pf
#include <torch/extension.h>
inline double Exp(double x) {
return exp(x);
}
inline float Exp(float x) {
return expf(x);
}
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= -1000) {
double res;
res = x + log1p(exp(diff));
return res;
}
return x; // return the larger one.
}
// returns log(exp(x) + exp(y)).
inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= -200) {
float res;
res = x + log1pf(expf(diff));
return res;
}
return x; // return the larger one.
}
// forward of mutual_information. See """... """ comment of `mutual_information` in
// mutual_information.py for documentation of the behavior of this function.
// px: [B, S, T+1]; py: [B, S+1, T]; boundary: [B, 4] or [0, 0]; p (output): [B, S+1, T+1].
torch::Tensor mutual_information_cpu(torch::Tensor px,
torch::Tensor py,
torch::Tensor boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0),
S = px.size(1),
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cpu() &&
boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
bool has_boundary = (boundary.size(0) != 0);
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(),
py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>();
auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
auto ans_a = ans.packed_accessor32<scalar_t, 1>();
for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end;
if (has_boundary) {
s_begin = boundary_a[b][0];
t_begin = boundary_a[b][1];
s_end = boundary_a[b][2];
t_end = boundary_a[b][3];
} else {
s_begin = 0;
t_begin = 0;
s_end = S;
t_end = T;
}
p_a[b][s_begin][t_begin] = 0.0;
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor> mutual_information_backward_cpu(
torch::Tensor px,
torch::Tensor py,
torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && p.device().is_cpu()
&& ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0),
S = px.size(1),
T = px.size(2) - 1;
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
TORCH_CHECK((boundary.size(0) == 0 && boundary.size(1) == 0) ||
(boundary.size(0) == B && boundary.size(1) == 4));
TORCH_CHECK(boundary.device().is_cpu() &&
boundary.dtype() == torch::kInt64);
bool has_boundary = (boundary.size(0) != 0);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T + 1}, opts) :
torch::empty({B, S, T + 1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) :
torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.packed_accessor32<scalar_t, 3>(),
// py_a = py.packed_accessor32<scalar_t, 3>(),
p_a = p.packed_accessor32<scalar_t, 3>(),
p_grad_a = p_grad.packed_accessor32<scalar_t, 3>(),
px_grad_a = px_grad.packed_accessor32<scalar_t, 3>(),
py_grad_a = py_grad.packed_accessor32<scalar_t, 3>();
auto ans_grad_a = ans_grad.packed_accessor32<scalar_t, 1>();
auto boundary_a = boundary.packed_accessor32<int64_t, 2>();
for (int b = 0; b < B; b++) {
int s_begin, s_end, t_begin, t_end;
if (has_boundary) {
s_begin = boundary_a[b][0];
t_begin = boundary_a[b][1];
s_end = boundary_a[b][2];
t_end = boundary_a[b][3];
} else {
s_begin = 0;
s_end = S;
t_begin = 0;
t_end = T;
}
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(p_a[b][s - 1][t] + px_a[b][s - 1][t],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
scalar_t term1 = p_a[b][s - 1][t] + px_a[b][s - 1][t],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t],
term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t],
term1_grad = term1_deriv * grad,
term2_grad = term2_deriv * grad;
px_grad_a[b][s - 1][t] = term1_grad;
p_grad_a[b][s - 1][t] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] = p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] = p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
(float)p_grad_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
}
}
}
}));
// std::cout << "p_grad = " << p_grad;
return std::vector<torch::Tensor>({px_grad, py_grad});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mutual_information_cpu", &mutual_information_cpu, "Integrated convolution forward function (CPU)");
m.def("mutual_information_backward_cpu", &mutual_information_backward_cpu, "Integrated convolution backward function (CPU)");
}
#include <torch/extension.h>
/*
Forward of mutual_information. See also """... """ comment of
`mutual_information` in mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
Args:
px: Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of lengths
(s, t), divided by the prior probability of generating x_s. (See
mutual_information.py for more info).
py: The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
p: This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
boundary: If set, a tensor of shape [B][4] of type int64_t, where,
for each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last) of the
x and y sequences that we should process. Alternatively, may be
a tensor of shape [0][0] and type int64_t; the elements will
default to (0, 0, S, T).
ans: a tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch::Tensor mutual_information_cuda(torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::Tensor boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
/*
backward of mutual_information; returns (grad_px, grad_py)
if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor> mutual_information_backward_cuda(
torch::Tensor px,
torch::Tensor py,
torch::Tensor boundary,
torch::Tensor p,
torch::Tensor ans_grad,
bool overwrite_ans_grad);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mutual_information_cuda", &mutual_information_cuda, "Mutual information forward function (CUDA)");
m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda, "Mutual information backward function (CUDA)");
}
# Caution: this will fail occasionally due to cutoffs not being quite large enough.
# As long as it passes most of the time, it's OK.
import random
import torch
from torch_mutual_information import mutual_information_recursion, joint_mutual_information_recursion
def test_mutual_information_basic():
print("Running test_mutual_information_basic()")
for _iter in range(100):
(B, S, T) = (random.randint(1, 10),
random.randint(1, 200),
random.randint(1, 200))
random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2)
print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}, random_boundary={random_boundary}")
for dtype in [torch.float32, torch.float64]:
px_grads = []
py_grads = []
m_vals = []
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
if random_boundary:
def get_boundary_row():
s_begin = random.randint(0, S - 1)
t_begin = random.randint(0, T - 1)
s_end = random.randint(s_begin, S) # allow empty sequence
t_end = random.randint(t_begin, T) # allow empty sequence
return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
dtype=torch.int64, device=device)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly or not.
if random.random() < 0.5:
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
else:
boundary = None
if device == torch.device('cpu'):
if random_px:
px = torch.randn(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
else:
px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
# px and py get exponentiated, and then multiplied together up to
# 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
# could lead to overflow.
if big_px:
px += 15.0
if random_py:
py = torch.randn(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
else:
py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
#m = mutual_information_recursion(px, py, None)
m = mutual_information_recursion(px, py, boundary)
m2 = joint_mutual_information_recursion((px,), (py,), boundary)
m3 = joint_mutual_information_recursion((px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary)
print("m3, before sum, = ", m3)
m3 = m3.sum(dim=0) # it is supposed to be identical only after
# summing over dim 0, corresponding to the
# sequence dim
print("m = ", m, ", size = ", m.shape)
print("m2 = ", m2, ", size = ", m2.shape)
print("m3 = ", m3, ", size = ", m3.shape)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3)
#print("exp(m) = ", m.exp())
# the loop this is in checks that the CPU and CUDA versions give the same
# derivative; by randomizing which of m, m2 or m3 we backprop, we also
# ensure that the joint version of the code gives the same derivative
# as the regular version
scale = 3
if random.random() < 0.5:
(m.sum() * scale).backward()
elif random.random() < 0.5:
(m2.sum() * scale).backward()
else:
(m3.sum() * scale).backward()
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
px_grads.append(px.grad.to('cpu'))
py_grads.append(py.grad.to('cpu'))
m_vals.append(m.to('cpu'))
if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-02, rtol=1.0e-02):
print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
assert 0
if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-02, rtol=1.0e-02):
print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
assert 0
if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-02, rtol=1.0e-02):
print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
assert 0
def test_mutual_information_deriv():
print("Running test_mutual_information_deriv()")
for _iter in range(100):
(B, S, T) = (random.randint(1, 10),
random.randint(1, 200),
random.randint(1, 200))
random_px = (random.random() < 0.2)
random_py = (random.random() < 0.2)
random_boundary = (random.random() < 0.7)
big_px = (random.random() < 0.2)
big_py = (random.random() < 0.2)
print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}, big_px={big_px}, big_py={big_py}, random_boundary={random_boundary}")
for dtype in [torch.float32, torch.float64]:
#px_grads = []
#py_grads = []
#m_vals = []
for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
print("dtype = ", dtype, ", device = ", device)
if random_boundary:
def get_boundary_row():
s_begin = random.randint(0, S - 1)
t_begin = random.randint(0, T - 1)
s_end = random.randint(s_begin + 1, S)
t_end = random.randint(t_begin + 1, T)
return [s_begin, t_begin, s_end, t_end]
if device == torch.device('cpu'):
boundary = torch.tensor([ get_boundary_row() for _ in range(B) ],
dtype=torch.int64, device=device)
else:
boundary = boundary.to(device)
else:
# Use default boundary, but either specified directly or not.
if random.random() < 0.5:
boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
else:
boundary = None
if device == torch.device('cpu'):
if random_px:
px = torch.randn(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
else:
px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
# px and py get exponentiated, and then multiplied together up to
# 32 times (BLOCK_SIZE in the CUDA code), so 15 is actually a big number that
# could lead to overflow.
if big_px:
px += 15.0
if random_py:
py = torch.randn(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
else:
py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
if big_py:
py += 15.0
else:
px = px.to(device).detach()
py = py.to(device).detach()
px.requires_grad = True
py.requires_grad = True
m = mutual_information_recursion(px, py, boundary)
#print("m = ", m)
#print("exp(m) = ", m.exp())
#print("px_grad = ", px.grad)
#print("py_grad = ", py.grad)
#px_grads.append(px.grad.to('cpu'))
#py_grads.append(py.grad.to('cpu'))
#m_vals.append(m.to('cpu'))
m_grad = torch.randn(B, dtype=dtype, device=device)
m.backward(gradient=m_grad)
delta = 1.0e-04
delta_px = delta * torch.randn_like(px)
m2 = mutual_information_recursion(px + delta_px, py, boundary)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to('cpu')
predicted_delta = (delta_px * px.grad).sum().to('cpu')
print(f"For px: observed,predicted objf changes are: {observed_delta},{predicted_delta}, absolute objf was {(m * m_grad).sum()}")
atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04
if not torch.allclose(observed_delta, predicted_delta, atol=atol, rtol=rtol):
print(f"Error: observed and predicted delta too different.")
assert 0
delta_py = delta * torch.randn_like(py)
m2 = mutual_information_recursion(px, py + delta_py, boundary)
delta_m = m2 - m
observed_delta = (delta_m * m_grad).sum().to('cpu')
predicted_delta = (delta_py * py.grad).sum().to('cpu')
print(f"For py: observed,predicted objf changes are: {observed_delta},{predicted_delta}, absolute objf was {(m * m_grad).sum()}")
# if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-02, rtol=1.0e-02):
# print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
# assert 0
# if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-02, rtol=1.0e-02):
# print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
# assert 0
# if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-02, rtol=1.0e-02):
# print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
# assert 0
if __name__ == "__main__":
#torch.set_printoptions(edgeitems=30)
test_mutual_information_basic()
test_mutual_information_deriv()
import os
import torch
from torch import Tensor
from typing import Tuple, Optional
from .mutual_information import mutual_information_recursion, joint_mutual_information_recursion
def get_rnnt_logprobs(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int) -> Tuple[Tensor, Tensor]:
"""
Reduces the RNN-T problem (the simple case, where the joiner network is just addition)
to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion(). This function is called from
rnnt_loss_simple(), but may be useful for other purposes.
Args:
lm: Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape:
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am: Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape:
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols: A LongTensor of shape [B][S], containing the symbols at each position
of the sequence, possibly including EOS
termination_symbol: The identity of the termination symbol, must be
in {0..C-1}
Returns: (px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1]
py: logprobs, of shape [B][S+1][T]
in the recursion:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert lm.ndim == 3 and am.ndim == 3 and lm.shape[0] == am.shape[0] and lm.shape[2] == am.shape[2]
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# subtracting am_max and lm_max is to ensure the probs are in a good range to do exp()
# without causing underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp()
lm_probs = (lm - lm_max).exp()
# normalizers: [B][S+1][T]
normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) + 1.0e-20).log()
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols..
px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)).squeeze(-1) # [B][S][T]
px_am = torch.cat((px_am,
torch.full((B, S, 1), float('-inf'),
device=px_am.device, dtype=px_am.dtype)),
dim=2) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(lm[:,:S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]
px = px_am + px_lm # [B][S][T+1], last slice indexed [:,:,T] is -inf
px[:,:,:T] -= normalizers[:,:S,:] # px: [B][S][T+1]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:,:,termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:,:,termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
return (px, py)
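# Illustrative sketch only (this helper name is made up and is not part of the
# library's API): a naive O(S*T) reference for the recursion documented in
# get_rnnt_logprobs() above, showing how (px, py) are consumed.  In practice
# use mutual_information_recursion(), which does this efficiently and supports
# autograd and boundaries.
def _naive_joint_score_sketch(px: Tensor, py: Tensor) -> Tensor:
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    p = torch.full((B, S + 1, T + 1), float('-inf'),
                   device=px.device, dtype=px.dtype)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            terms = []
            if s > 0:
                terms.append(p[:, s - 1, t] + px[:, s - 1, t])
            if t > 0:
                terms.append(p[:, s, t - 1] + py[:, s, t - 1])
            p[:, s, t] = torch.logsumexp(torch.stack(terms), dim=0)
    # p[:, S, T] is the total log-prob of each sequence, i.e. (roughly) what
    # mutual_information_recursion() returns when boundary is None.
    return p[:, S, T]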
def rnnt_loss_simple(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
boundary: Tensor = None) -> Tensor:
"""
A simple case of the RNN-T loss, where the 'joiner' network is just addition.
Returns negated total loss value.
Args:
lm: language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes
am: acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols: the symbol sequences, a LongTensor of shape [B][S], and elements in {0..C-1}.
termination_symbol: the termination symbol, with 0 <= termination_symbol < C
boundary: a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Returns:
a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
for each element of the batch (like log-probs of sequences).
"""
px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
return mutual_information_recursion(px, py, boundary)
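# Example usage (illustrative only; shapes follow the docstring above):
#   B, S, T, C = 2, 5, 10, 8
#   lm = torch.randn(B, S + 1, C)
#   am = torch.randn(B, T, C)
#   symbols = torch.randint(1, C, (B, S))   # avoid 0, used as termination below
#   neg_loss = rnnt_loss_simple(lm, am, symbols, termination_symbol=0,
#                               boundary=None)   # shape (B,), negated loss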
def get_rnnt_logprobs_aux(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1) -> Tuple[Tensor, Tensor]:
"""
Reduces the RNN-T problem (the simple case, where the joiner network is just
addition) to a compact, standard form that can then be given
(with boundaries) to mutual_information_recursion(). This version allows you
to make the loss function of the form:
lm_only_scale * lm_probs +
am_only_scale * am_probs +
(1 - lm_only_scale - am_only_scale) * combined_probs
where lm_probs and am_probs are the probabilities obtained from the LM and the
acoustic model independently.
This function is called from
rnnt_loss_aux(), but may be useful for other purposes.
Args:
lm: Language model part of un-normalized logprobs of symbols, to be added to
acoustic model part before normalizing. Of shape:
[B][S+1][C]
where B is the batch size, S is the maximum sequence length of
the symbol sequence, possibly including the EOS symbol; and
C is the size of the symbol vocabulary, including the termination/next-frame
symbol.
Conceptually, lm[b][s] is a vector of length [C] representing the
"language model" part of the un-normalized logprobs of symbols,
given all symbols *earlier than* s in the sequence. The reason
we still need this for position S is that we may still be emitting
the termination/next-frame symbol at this point.
am: Acoustic-model part of un-normalized logprobs of symbols, to be added
to language-model part before normalizing. Of shape:
[B][T][C]
where B is the batch size, T is the maximum sequence length of
the acoustic sequences (in frames); and C is the size of the symbol
vocabulary, including the termination/next-frame symbol. It reflects
the "acoustic" part of the probability of any given symbol appearing
next on this frame.
symbols: A LongTensor of shape [B][S], containing the symbols at each position
of the sequence, possibly including EOS
termination_symbol: The identity of the termination symbol, must be
in {0..C-1}
Returns: (px, py) (the names are quite arbitrary).
px: logprobs, of shape [B][S][T+1]
py: logprobs, of shape [B][S+1][T]
in the recursion:
p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
.. where p[b][s][t] is the "joint score" of the pair of subsequences of
length s and t respectively. px[b][s][t] represents the probability of
extending the subsequences of length (s,t) by one in the s direction,
given the particular symbol, and py[b][s][t] represents the probability
of extending the subsequences of length (s,t) by one in the t direction,
i.e. of emitting the termination/next-frame symbol.
px[:,:,T] equals -infinity, meaning on the "one-past-the-last" frame
we cannot emit any symbols. This is simply a way of incorporating
the probability of the termination symbol on the last frame.
"""
assert (lm.ndim == 3 and am.ndim == 3
        and lm.shape[0] == am.shape[0] and lm.shape[2] == am.shape[2])
(B, T, C) = am.shape
S = lm.shape[1] - 1
assert symbols.shape == (B, S)
# Caution: some parts of this code are a little less clear than they could
# be due to optimizations. In particular it may not be totally obvious that
# all of the logprobs here are properly normalized. We test that
# this code is invariant to adding constants in the appropriate ways.
# subtracting am_max and lm_max is to ensure the probs are in a good range to do exp()
# without causing underflow or overflow.
am_max, _ = torch.max(am, dim=2, keepdim=True) # am_max: [B][T][1]
lm_max, _ = torch.max(lm, dim=2, keepdim=True) # lm_max: [B][S+1][1]
am_probs = (am - am_max).exp() # [B][T][C]
lm_probs = (lm - lm_max).exp() # [B][S+1][C]
# normalizers: [B][S+1][T]
normalizers = (torch.matmul(lm_probs, am_probs.transpose(1, 2)) + 1.0e-20).log()
# normalizer per frame, if we take only the LM probs by themselves
lmonly_normalizers = lm_probs.sum(dim=2, keepdim=True) # lmonly_normalizers: [B][S+1][1]
unigram_lm = torch.mean(lm_probs / lmonly_normalizers, dim=(0, 1),
                        keepdim=True) + 1.0e-20  # [1][1][C]
amonly_normalizers = torch.mv(am_probs.reshape(-1, C),
                              unigram_lm.reshape(C)).reshape(B, T, 1).log() + am_max  # [B][T][1]
amonly_normalizers = amonly_normalizers.transpose(1, 2) # [B][1][T]
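# unigram_lm[0][0][c] is the average, over batch items and symbol positions,
# of the normalized LM probability of symbol c: a history-independent
# ("unigram") LM used for the AM-only part of the loss.
# amonly_normalizers[b][0][t] is the log-normalizer for frame t when the real
# LM is replaced by this unigram distribution.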
unigram_lm = unigram_lm.log()
lmonly_normalizers = lmonly_normalizers.log() + lm_max # [B][S+1][1], log-normalizer, used for LM-only part of prob.
# add lm_max and am_max to normalizers, to make it as if we had not
# subtracted am_max and lm_max above.
normalizers = normalizers + lm_max + am_max.transpose(1, 2) # [B][S+1][T]
# px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather(am.unsqueeze(1).expand(B, S, T, C), dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1)).squeeze(-1) # [B][S][T]
px_am = torch.cat((px_am,
torch.full((B, S, 1), float('-inf'),
device=px_am.device, dtype=px_am.dtype)),
dim=2) # now: [B][S][T+1], index [:,:,T] has -inf..
px_lm = torch.gather(lm[:,:S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]
px_lm_unigram = torch.gather(unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)) # [B][S][1]
px = px_am + px_lm # [B][S][T+1], last slice indexed [:,:,T] is -inf
px[:,:,:T] -= normalizers[:,:S,:] # px: [B][S][T+1]
px_amonly = px_am + px_lm_unigram # [B][S][T+1]
px_amonly[:,:,:T] -= amonly_normalizers
px_lmonly = px_lm - lmonly_normalizers[:,:S,:]
# py is the probs of termination symbols, of shape [B][S+1][T]
py_am = am[:,:,termination_symbol].unsqueeze(1) # [B][1][T]
py_lm = lm[:,:,termination_symbol].unsqueeze(2) # [B][S+1][1]
py = py_am + py_lm - normalizers
py_lm_unigram = unigram_lm[0][0][termination_symbol] # scalar, normalized..
py_amonly = py_am + py_lm_unigram - amonly_normalizers # [B][S+1][T]
py_lmonly = py_lm - lmonly_normalizers # [B][S+1][T]
combined_scale = 1.0 - lm_only_scale - am_only_scale
# We need to avoid exact zeros in the scales because otherwise multiplying -inf
# by zero generates nan.
if lm_only_scale == 0.0:
lm_only_scale = 1.0e-20
if am_only_scale == 0.0:
am_only_scale = 1.0e-20
px_interp = px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale
py_interp = py * combined_scale + py_lmonly * lm_only_scale + py_amonly * am_only_scale
print("px_interp = ", px_interp)
print("py_interp = ", py_interp)
return (px_interp, py_interp)
def rnnt_loss_aux(lm: Tensor,
am: Tensor,
symbols: Tensor,
termination_symbol: int,
lm_only_scale: float = 0.1,
am_only_scale: float = 0.1,
boundary: Tensor = None) -> Tensor:
"""
A simple case of the RNN-T loss, where the 'joiner' network is just addition;
this version also supports interpolation with "LM-only" and "AM-only" terms
(see lm_only_scale and am_only_scale below).
Returns negated total loss value.
Args:
lm: language-model part of unnormalized log-probs of symbols, with shape
(B, S+1, C), i.e. batch, symbol_seq_len+1, num_classes.
These are assumed to be well-normalized, in the sense that we could
use them as probabilities separately from the am scores
am: acoustic-model part of unnormalized log-probs of symbols, with shape
(B, T, C), i.e. batch, frame, num_classes
symbols: the symbol sequences, a LongTensor of shape [B][S], and elements in {0..C-1}.
termination_symbol: the termination symbol, with 0 <= termination_symbol < C
lm_only_scale: the scale on the "LM-only" part of the loss.
am_only_scale: the scale on the "AM-only" part of the loss, for which we use
an "averaged" LM (averaged over all histories, so effectively unigram).
boundary: a LongTensor of shape [B, 4] with elements interpreted as
[begin_symbol, begin_frame, end_symbol, end_frame] that is treated as [0, 0, S, T]
if boundary is not supplied.
Most likely you will want begin_symbol and begin_frame to be zero.
Returns:
a Tensor of shape (B,), containing the NEGATED total RNN-T loss values
for each element of the batch (like log-probs of sequences).
"""
px, py = get_rnnt_logprobs_aux(lm, am, symbols, termination_symbol,
lm_only_scale, am_only_scale)
return mutual_information_recursion(px, py, boundary)
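# Example usage (illustrative only): interpolate the combined loss with
# LM-only and AM-only terms, here with 10% weight each:
#   neg_loss = rnnt_loss_aux(lm, am, symbols, termination_symbol=0,
#                            lm_only_scale=0.1, am_only_scale=0.1,
#                            boundary=None)   # shape (B,)
#   # With lm_only_scale=0.0 and am_only_scale=0.0 this reduces (up to the
#   # tiny epsilon used to avoid 0 * -inf) to rnnt_loss_simple().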
import random
import torch
from torch_mutual_information import (mutual_information_recursion,
                                      joint_mutual_information_recursion,
                                      get_rnnt_logprobs, rnnt_loss_simple,
                                      rnnt_loss_aux)
def test_rnnt_logprobs_basic():
print("Running test_rnnt_logprobs_basic()")
B = 1
S = 3
T = 4
C = 3
# lm: [B][S+1][C]
lm = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], dtype=torch.float)
# am: [B][T][C]
am = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], dtype=torch.float)
# lm[:] = 0.0
# am[:] = 0.0
termination_symbol = 2
symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)
px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol)
assert px.shape == (B, S, T+1)
assert py.shape == (B, S+1, T)
assert symbols.shape == (B, S)
print("px = ", px)
print("py = ", py)
m = mutual_information_recursion(px, py)
print("m = ", m)
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S+1, 1)
am += torch.randn(B, T, 1)
m2 = rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
print("m2 = ", m2)
device = torch.device('cuda')
m3 = rnnt_loss_simple(lm.to(device), am.to(device), symbols.to(device), termination_symbol, None)
print("m3 = ", m3)
device = torch.device('cuda')
m4 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device), termination_symbol,
lm_only_scale=0.0, am_only_scale=0.0, boundary=None)
print("m4 = ", m4)
assert torch.allclose(m, m2)
assert torch.allclose(m, m3.to('cpu'))
assert torch.allclose(m, m4.to('cpu'))
def test_rnnt_logprobs_aux():
print("Running test_rnnt_logprobs_aux()")
B = 1
S = 3
T = 4
C = 3
# lm: [B][S+1][C]
lm = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], dtype=torch.float)
# am: [B][T][C]
am = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], dtype=torch.float)
termination_symbol = 2
symbols = torch.tensor([[0, 1, 0]], dtype=torch.long)
device = torch.device('cuda')
m1 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device), termination_symbol,
lm_only_scale=0.0, am_only_scale=0.333, boundary=None)
print("m1 = ", m1)
# should be invariant to adding a constant for any frame.
lm += torch.randn(B, S+1, 1)
am += torch.randn(B, T, 1)
m2 = rnnt_loss_aux(lm.to(device), am.to(device), symbols.to(device), termination_symbol,
lm_only_scale=0.0, am_only_scale=0.333, boundary=None)
print("m2 = ", m2)
assert torch.allclose(m1, m2)
if __name__ == "__main__":
#torch.set_printoptions(edgeitems=30)
test_rnnt_logprobs_aux()
test_rnnt_logprobs_basic()