Unverified Commit 30de797c authored by moto's avatar moto Committed by GitHub
Browse files

[BC-Breaking] Remove kaldi.resample_waveform (#1555)

`torchaudio.compliance.kaldi.resample_waveform` has been replaced with `torchaudio.functional.resample`.
parent 5432a3f5
import argparse
import logging
import os
import random
import subprocess
import torch
import torchaudio
import utils
from torchaudio_unittest import common_utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
    """Run the fbank executable on randomly drawn configurations.

    Each of the ``num_outputs`` attempts draws a random configuration, encodes
    it in the output name (``fbank-<values joined by '-'>.ark``) and invokes
    ``exe_path`` with one ``--key=value`` flag per entry plus ``--dither=0.0``.
    Attempts whose frequency / VTLN values are inconsistent are skipped, so
    fewer than ``num_outputs`` files may be produced.

    Args:
        exe_path (str): Executable to invoke.
        scp_path (str): Input argument passed through to the executable.
        out_dir (str): Prefix prepended verbatim to every output name.
        wave_len (int): Upper bound (exclusive) used when drawing the frame
            length and frame shift, in samples.
        num_outputs (int): Number of configurations to attempt.
        remove_files (bool): Delete the output file when the invocation fails.
        log_level (str or int): Level handed to ``logging.basicConfig``.
    """
    logging.basicConfig(level=log_level)
    nyquist = 16000 // 2
    for _ in range(num_outputs):
        try:
            high_freq = random.randint(1, nyquist)
            low_freq = random.randint(0, high_freq - 1)
            vtln_low = random.randint(low_freq + 1, high_freq - 1)
            vtln_high = random.randint(vtln_low + 1, high_freq - 1)
            vtln_warp_factor = random.uniform(0.0, 10.0) if random.random() < 0.3 else 1.0
        except ValueError:
            # randint raises ValueError when a previous draw left an empty range
            continue

        if not ((0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)):
            continue
        if not (vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
                                            (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high))):
            continue

        # Insertion order matters: it defines the order of values in the file name.
        inputs = {
            'blackman_coeff': '%.4f' % (random.random() * 5),
            'energy_floor': '%.4f' % (random.random() * 5),
            'frame_length': '%.4f' % (float(random.randint(3, wave_len - 1)) / 16000 * 1000),
            'frame_shift': '%.4f' % (float(random.randint(1, wave_len - 1)) / 16000 * 1000),
            'high_freq': str(high_freq),
            'htk_compat': utils.generate_rand_boolean(),
            'low_freq': str(low_freq),
            'num_mel_bins': str(random.randint(4, 8)),
            'preemphasis_coefficient': '%.2f' % random.random(),
            'raw_energy': utils.generate_rand_boolean(),
            'remove_dc_offset': utils.generate_rand_boolean(),
            'round_to_power_of_two': utils.generate_rand_boolean(),
            'snip_edges': utils.generate_rand_boolean(),
            'subtract_mean': utils.generate_rand_boolean(),
            'use_energy': utils.generate_rand_boolean(),
            'use_log_fbank': utils.generate_rand_boolean(),
            'use_power': utils.generate_rand_boolean(),
            'vtln_high': str(vtln_high),
            'vtln_low': str(vtln_low),
            'vtln_warp': '%.4f' % (vtln_warp_factor),
            'window_type': utils.generate_rand_window_type()
        }

        fn = 'fbank-' + '-'.join(inputs.values())
        out_fn = out_dir + fn + '.ark'

        arg = [exe_path]
        arg += ['--' + key.replace('_', '-') + '=' + value for key, value in inputs.items()]
        arg += ['--dither=0.0', scp_path, out_fn]

        logging.info(fn)
        logging.info(inputs)
        logging.info(' '.join(arg))

        # The child's output is only shown at INFO level. DEVNULL avoids opening
        # (and never closing) os.devnull twice per iteration.
        redirect = {} if log_level == 'INFO' else {'stderr': subprocess.DEVNULL, 'stdout': subprocess.DEVNULL}
        try:
            returncode = subprocess.call(arg, **redirect)
        except Exception:
            # Best effort: report the failure and move on to the next configuration.
            logging.exception('could not run %s', exe_path)
            succeeded = False
        else:
            succeeded = returncode == 0
            if succeeded:
                logging.info('success')
            else:
                logging.warning('%s exited with status %d', exe_path, returncode)

        if not succeeded and remove_files:
            # out_fn may carry an 'ark:' prefix (see the usage examples in this
            # script); the file on disk is the part after that prefix.
            out_path = out_fn[len('ark:'):] if out_fn.startswith('ark:') else out_fn
            if os.path.exists(out_path):
                os.remove(out_path)
def decode(fn, sound_path, exe_path, scp_path, out_dir):
    """
    Takes a file name produced by ``run`` and logs the corresponding shell command to run
    that specific kaldi configuration. It also calls compliance.kaldi and logs the two outputs.

    Everything is emitted through ``logging.info``, so it is only visible when the root
    logger is configured at INFO level or lower.

    Args:
        fn (str): Name of the form ``fbank-<21 values joined by '-'>.ark``.
        sound_path (str): Wav file handed to ``torchaudio.compliance.kaldi.fbank``.
        exe_path (str): Executable shown in the logged command.
        scp_path (str): Input argument shown in the logged command.
        out_dir (str): Prefix prepended verbatim to ``fn`` to locate the kaldi output.

    Raises:
        AssertionError: If ``fn`` does not split into the expected number of values.

    Example:
        >> fn = 'fbank-1.1009-2.5985-1.1875-0.8750-5723-true-918-4-0.31-true-false-true-true-' \
                'false-false-false-true-4595-4281-1.0000-hamming.ark'
        >> decode(fn, sound_path, exe_path, scp_path, out_dir)
    """
    out_fn = out_dir + fn
    fn = fn[len('fbank-'):-len('.ark')]

    # Same order as the values appear in the file name.
    arr = [
        'blackman_coeff', 'energy_floor', 'frame_length', 'frame_shift', 'high_freq', 'htk_compat',
        'low_freq', 'num_mel_bins', 'preemphasis_coefficient', 'raw_energy', 'remove_dc_offset',
        'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
        'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
    fn_split = fn.split('-')
    assert len(fn_split) == len(arr), ('Len mismatch: {} and {}'.format(len(fn_split), len(arr)))
    inputs = {name: utils.parse(token) for name, token in zip(arr, fn_split)}

    # print flags for C++
    s = ' '.join('--' + name.replace('_', '-') + '=' + token for name, token in zip(arr, fn_split))
    logging.info(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
    # logging.info requires a message argument; an empty one keeps the blank separator line
    logging.info('')

    # print args for python
    inputs['dither'] = 0.0
    logging.info(inputs)
    sound, _ = common_utils.load_wav(sound_path, normalize=False)
    kaldi_output_dict = dict(torchaudio.kaldi_io.read_mat_ark(out_fn))
    res = torchaudio.compliance.kaldi.fbank(sound, **inputs)
    torch.set_printoptions(precision=10, sci_mode=False)
    logging.info(res)
    logging.info(kaldi_output_dict['my_id'])
def str_to_bool(value):
    """Convert a command-line token to a bool, for use as an argparse ``type``.

    ``type=bool`` treats every non-empty string as ``True``, so ``--flag=false``
    would enable the flag. This converter reads the text instead.

    Args:
        value (str or bool): Token from the command line.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    # Examples:
    # >> python test/compliance/generate_fbank_data.py \
    #   --exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-fbank-feats \
    #   --scp_path=scp:/scratch/jamarshon/downloads/a.scp \
    #   --out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/
    #
    # >> python test/compliance/generate_fbank_data.py \
    #   --exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-fbank-feats \
    #   --scp_path=scp:/scratch/jamarshon/downloads/a.scp \
    #   --out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/ \
    #   --decode=true \
    #   --sound_path=/scratch/jamarshon/audio/test/assets/kaldi_file.wav \
    #   --fn="fbank-1.1009-2.5985-1.1875-0.8750-5723-true-918-4-0.31-true-false-true-true-false-false-false-true-4595-4281-1.0000-hamming.ark"
    parser = argparse.ArgumentParser(description='Generate fbank data using Kaldi.')
    parser.add_argument('--exe_path', type=str, required=True, help='Path to the compute-fbank-feats executable.')
    parser.add_argument('--scp_path', type=str, required=True,
                        help='Path to the scp file. An example of its contents would be '
                             '"my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav". '
                             'where the space separates an id from a wav file.')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='The directory to which the stft features will be written to.')

    # run arguments
    parser.add_argument('--wave_len', type=int, default=20,
                        help='The number of samples inside the input wave file read from `scp_path`')
    parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
    parser.add_argument('--remove_files', type=str_to_bool, default=False,
                        help='Whether to remove files generated from exception')
    parser.add_argument('--log_level', type=str, default='WARNING',
                        help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')

    # decode arguments
    parser.add_argument('--decode', type=str_to_bool, default=False,
                        help='Whether to run the decode or run function.')
    parser.add_argument('--fn', type=str, help='Filepath to decode.')
    parser.add_argument('--sound_path', type=str, help='Sound filepath to decode.')

    args = parser.parse_args()

    if args.decode:
        # decode concatenates and slices these values, so fail early with a usage message
        if args.fn is None or args.sound_path is None:
            parser.error('--fn and --sound_path are required when --decode is set')
        decode(args.fn, args.sound_path, args.exe_path, args.scp_path, args.out_dir)
    else:
        run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
            args.remove_files, args.log_level)
import argparse
import logging
import os
import random
import subprocess
import utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
    """Run the spectrogram executable on randomly drawn configurations.

    Each of the ``num_outputs`` iterations draws a random configuration, encodes
    it in the output name (``spec-<values joined by '-'>.ark``) and invokes
    ``exe_path`` with one ``--key=value`` flag per entry.

    Args:
        exe_path (str): Executable to invoke.
        scp_path (str): Input argument passed through to the executable.
        out_dir (str): Prefix prepended verbatim to every output name.
        wave_len (int): Upper bound (exclusive) used when drawing the frame
            length and frame shift, in samples.
        num_outputs (int): Number of configurations to generate.
        remove_files (bool): Delete the output file when the invocation fails.
        log_level (str or int): Level handed to ``logging.basicConfig``.
    """
    logging.basicConfig(level=log_level)
    for _ in range(num_outputs):
        # Insertion order matters: it defines the order of values in the file name.
        inputs = {
            'blackman_coeff': '%.4f' % (random.random() * 5),
            'dither': '0',
            'energy_floor': '%.4f' % (random.random() * 5),
            'frame_length': '%.4f' % (float(random.randint(2, wave_len - 1)) / 16000 * 1000),
            'frame_shift': '%.4f' % (float(random.randint(1, wave_len - 1)) / 16000 * 1000),
            'preemphasis_coefficient': '%.2f' % random.random(),
            'raw_energy': utils.generate_rand_boolean(),
            'remove_dc_offset': utils.generate_rand_boolean(),
            'round_to_power_of_two': utils.generate_rand_boolean(),
            'snip_edges': utils.generate_rand_boolean(),
            'subtract_mean': utils.generate_rand_boolean(),
            'window_type': utils.generate_rand_window_type()
        }

        fn = 'spec-' + '-'.join(inputs.values())
        out_fn = out_dir + fn + '.ark'

        arg = [exe_path]
        arg += ['--' + key.replace('_', '-') + '=' + value for key, value in inputs.items()]
        arg += [scp_path, out_fn]

        logging.info(fn)
        logging.info(inputs)
        logging.info(' '.join(arg))

        # The child's output is only shown at INFO level. DEVNULL avoids opening
        # (and never closing) os.devnull twice per iteration.
        redirect = {} if log_level == 'INFO' else {'stderr': subprocess.DEVNULL, 'stdout': subprocess.DEVNULL}
        try:
            returncode = subprocess.call(arg, **redirect)
        except Exception:
            # Best effort: report the failure and move on to the next configuration.
            logging.exception('could not run %s', exe_path)
            succeeded = False
        else:
            succeeded = returncode == 0
            if succeeded:
                logging.info('success')
            else:
                logging.warning('%s exited with status %d', exe_path, returncode)

        if not succeeded and remove_files:
            # out_fn may carry an 'ark:' prefix (see the usage example in this
            # script); the file on disk is the part after that prefix.
            out_path = out_fn[len('ark:'):] if out_fn.startswith('ark:') else out_fn
            if os.path.exists(out_path):
                os.remove(out_path)
def str_to_bool(value):
    """Convert a command-line token to a bool, for use as an argparse ``type``.

    ``type=bool`` treats every non-empty string as ``True``, so ``--flag=false``
    would enable the flag. This converter reads the text instead.

    Args:
        value (str or bool): Token from the command line.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    # Examples:
    # >> python test/compliance/generate_test_stft_data.py \
    #   --exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-spectrogram-feats \
    #   --scp_path=scp:/scratch/jamarshon/downloads/a.scp \
    #   --out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/
    parser = argparse.ArgumentParser(description='Generate spectrogram data using Kaldi.')
    parser.add_argument('--exe_path', type=str, required=True,
                        help='Path to the compute-spectrogram-feats executable.')
    parser.add_argument('--scp_path', type=str, required=True,
                        help='Path to the scp file. An example of its contents would be '
                             '"my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav". '
                             'where the space separates an id from a wav file.')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='The directory to which the stft features will be written to.')

    # run arguments
    parser.add_argument('--wave_len', type=int, default=20,
                        help='The number of samples inside the input wave file read from `scp_path`')
    parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
    parser.add_argument('--remove_files', type=str_to_bool, default=False,
                        help='Whether to remove files generated from exception')
    parser.add_argument('--log_level', type=str, default='WARNING',
                        help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')

    args = parser.parse_args()

    run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
        args.remove_files, args.log_level)
import random
import torchaudio
# Leading tokens of the generated test file names (the generator scripts emit
# 'fbank-...' and 'spec-...'). `parse` returns any token listed here unchanged.
TEST_PREFIX = ['spec', 'fbank', 'mfcc', 'resample']
def generate_rand_boolean():
    # Draws 0 or 1 and returns the matching lowercase string ('false' / 'true')
    return ('false', 'true')[random.randint(0, 1)]
def generate_rand_window_type():
    # Returns the entry of torchaudio's kaldi WINDOWS found at a randomly drawn index
    windows = torchaudio.compliance.kaldi.WINDOWS
    last_index = len(windows) - 1
    return windows[random.randint(0, last_index)]
def parse(token):
    # Turns one '-'-separated piece of a test file name back into a python value:
    # 'true'/'false' -> bool, window names and test prefixes -> unchanged str,
    # anything containing '.' -> float, everything else -> int
    for text, value in (('true', True), ('false', False)):
        if token == text:
            return value

    known_words = torchaudio.compliance.kaldi.WINDOWS
    if token in known_words or token in TEST_PREFIX:
        return token

    return float(token) if '.' in token else int(token)
import os
import math
import torch import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .compliance import utils as compliance_utils
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
...@@ -45,21 +40,8 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): ...@@ -45,21 +40,8 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
window[f, s] = wave[s_in_wave] window[f, s] = wave[s_in_wave]
@common_utils.skipIfNoSox
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
# Asset locations, resolved through common_utils.get_asset_path
kaldi_output_dir = common_utils.get_asset_path('kaldi')
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
# One (initially empty) list of file names per known prefix
test_filepaths = {prefix: [] for prefix in compliance_utils.TEST_PREFIX}
# separating test files by their types (e.g 'spec', 'fbank', etc.)
# NOTE: this runs in the class body, so the loop variables below also end up
# as class attributes.
for f in os.listdir(kaldi_output_dir):
dash_idx = f.find('-')
# every file in the directory must be named '<prefix>-....ark'
assert f.endswith('.ark') and dash_idx != -1
# the text before the first '-' selects the bucket
key = f[:dash_idx]
assert key in test_filepaths
test_filepaths[key].append(f)
def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges): def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
waveform = torch.arange(num_samples).float() waveform = torch.arange(num_samples).float()
output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges) output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)
...@@ -89,76 +71,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): ...@@ -89,76 +71,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
for snip_edges in range(0, 2): for snip_edges in range(0, 2):
self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges) self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges)
def _create_data_set(self):
    # Offline helper that writes the wav asset the tests read; it is not run as a test.
    sample_rate_hz = 16000
    t = torch.arange(0, 20).float()

    # amplitudes 1 + 3 + 2 keep the signal within [-6, 6]
    signal = torch.cos(2 * math.pi * t) + 3 * torch.sin(math.pi * t) + 2 * torch.cos(t)
    # rescale to [-2^30, 2^30] and truncate to integers
    signal = (signal / 6 * (1 << 30)).long()
    # zero the low 16 bits, then go back to float
    signal = ((signal >> 16) << 16).float()

    torchaudio.save(self.test_filepath, signal, sample_rate_hz)
    # read the file back to check that it round-trips
    loaded, loaded_rate = common_utils.load_wav(self.test_filepath, normalize=False)
    print(signal >> 16)
    self.assertTrue(loaded_rate == sample_rate_hz)
    self.assertEqual(signal, loaded)
def _print_diagnostic(self, output, expect_output):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
abs_error = output - expect_output
abs_mse = abs_error.pow(2).sum() / output.numel()
abs_max_error = torch.max(abs_error.abs())
relative_error = abs_error / expect_output
relative_mse = relative_error.pow(2).sum() / output.numel()
relative_max_error = torch.max(relative_error.abs())
print('abs_mse:', abs_mse.item(), 'abs_max_error:', abs_max_error.item())
print('relative_mse:', relative_mse.item(), 'relative_max_error:', relative_max_error.item())
def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
                            expected_num_args, get_output_fn, atol=1e-5, rtol=1e-7):
    """
    Compares the output of ``get_output_fn`` with every stored kaldi file of one kind.

    The configuration for each comparison is recovered from the kaldi file's name:
    the name is split on '-', the extension is dropped and every piece is parsed.

    Inputs:
        sound_filepath (str): The location of the sound file
        filepath_key (str): A key to `test_filepaths` which selects the files to use
        expected_num_files (int): How many kaldi files are expected under that key
        expected_num_args (int): How many '-'-separated pieces each file name must have
        get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
            and a configuration and returns an output
        atol (float): absolute tolerance
        rtol (float): relative tolerance
    """
    waveform, _ = common_utils.load_wav(sound_filepath, normalize=False)
    ark_names = self.test_filepaths[filepath_key]

    assert len(ark_names) == expected_num_files, \
        ('number of kaldi {} file changed to {}'.format(
            filepath_key, len(ark_names)))

    for ark_name in ark_names:
        print(ark_name)

        # Load the single matrix kaldi stored in this file
        ark_path = os.path.join(self.kaldi_output_dir, ark_name)
        matrices = dict(torchaudio.kaldi_io.read_mat_ark(ark_path))
        assert len(matrices) == 1 and 'my_id' in matrices, 'invalid test kaldi ark file'
        expected = matrices['my_id']

        # Rebuild the configuration from the file name; only the last piece carries the extension
        *leading, trailing = ark_name.split('-')
        tokens = leading + [os.path.splitext(trailing)[0]]
        assert len(tokens) == expected_num_args, 'invalid test kaldi file name'
        config = [compliance_utils.parse(token) for token in tokens]

        actual = get_output_fn(waveform, config)

        self._print_diagnostic(actual, expected)
        self.assertEqual(actual, expected, atol=atol, rtol=rtol)
def test_mfcc_empty(self): def test_mfcc_empty(self):
# Passing in an empty tensor should result in an error # Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0)) self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))
...@@ -6,7 +6,6 @@ from torch import Tensor ...@@ -6,7 +6,6 @@ from torch import Tensor
import torchaudio import torchaudio
import torchaudio._internal.fft import torchaudio._internal.fft
from torchaudio._internal.module_utils import deprecated
__all__ = [ __all__ = [
'get_mel_banks', 'get_mel_banks',
...@@ -19,7 +18,6 @@ __all__ = [ ...@@ -19,7 +18,6 @@ __all__ = [
'mfcc', 'mfcc',
'vtln_warp_freq', 'vtln_warp_freq',
'vtln_warp_mel_freq', 'vtln_warp_mel_freq',
'resample_waveform',
] ]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07 # numeric_limits<float>::epsilon() 1.1920928955078125e-07
...@@ -751,32 +749,3 @@ def mfcc( ...@@ -751,32 +749,3 @@ def mfcc(
feature = _subtract_column_mean(feature, subtract_mean) feature = _subtract_column_mean(feature, subtract_mean)
return feature return feature
@deprecated("Please use `torchaudio.functional.resample`.", "0.10")
def resample_waveform(waveform: Tensor,
                      orig_freq: float,
                      new_freq: float,
                      lowpass_filter_width: int = 6,
                      rolloff: float = 0.99,
                      resampling_method: str = "sinc_interpolation") -> Tensor:
    r"""Resamples the waveform at the new frequency.

    Deprecated thin wrapper: every argument is forwarded, in order, to
    ``torchaudio.functional.resample`` and its result is returned unchanged.

    Args:
        waveform (Tensor): The input signal of size (..., time)
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies.
            (Default: ``0.99``)
        resampling_method (str, optional): The resampling method to use.
            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)

    Returns:
        Tensor: The waveform at the new frequency
    """
    forwarded = (waveform, orig_freq, new_freq, lowpass_filter_width, rolloff, resampling_method)
    return torchaudio.functional.resample(*forwarded)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment