Unverified Commit 08a71271 authored by gmagogsfm's avatar gmagogsfm Committed by GitHub
Browse files

Switch string formatting to str.format to be TorchScript friendly. (#850)

parent 3bab2b29
......@@ -92,7 +92,7 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir):
'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
fn_split = fn.split('-')
assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr)))
assert len(fn_split) == len(arr), ('Len mismatch: {} and {}'.format(len(fn_split), len(arr)))
inputs = {arr[i]: utils.parse(fn_split[i]) for i in range(len(arr))}
# print flags for C++
......
......@@ -148,7 +148,9 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
sound, sr = torchaudio.load_wav(sound_filepath)
files = self.test_filepaths[filepath_key]
assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))
assert len(files) == expected_num_files, \
('number of kaldi {} file changed to {}'.format(
filepath_key, len(files)))
for f in files:
print(f)
......
......@@ -135,13 +135,15 @@ def _get_waveform_and_window_properties(waveform: Tensor,
r"""Gets the waveform and window properties
"""
channel = max(channel, 0)
assert channel < waveform.size(0), ('Invalid channel %d for size %d' % (channel, waveform.size(0)))
assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0)))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(waveform), ('choose a window size %d that is [2, %d]' % (window_size, len(waveform)))
assert 2 <= window_size <= len(
waveform), ('choose a window size {} that is [2, {}]'
.format(window_size, len(waveform)))
assert 0 < window_shift, '`window_shift` must be greater than 0'
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
' use `round_to_power_of_two` or change `frame_length`'
......@@ -430,7 +432,7 @@ def get_mel_banks(num_bins: int,
high_freq += nyquist
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
('Bad values in options: low-freq %f and high-freq %f vs. nyquist %f' % (low_freq, high_freq, nyquist))
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
......@@ -446,8 +448,8 @@ def get_mel_banks(num_bins: int,
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' %
(vtln_low, vtln_high, low_freq, high_freq))
('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
......
......@@ -149,8 +149,8 @@ def griffinlim(
Returns:
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
"""
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
assert momentum >= 0, 'momentum=%s < 0' % momentum
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)
# pack batch
shape = specgram.size()
......
......@@ -141,8 +141,8 @@ class GriffinLim(torch.nn.Module):
rand_init: bool = True) -> None:
super(GriffinLim, self).__init__()
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
assert momentum > 0, 'momentum=%s < 0' % momentum
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum > 0, 'momentum={} < 0'.format(momentum)
self.n_fft = n_fft
self.n_iter = n_iter
......@@ -237,7 +237,7 @@ class MelScale(torch.nn.Module):
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
......@@ -313,7 +313,7 @@ class InverseMelScale(torch.nn.Module):
self.tolerance_change = tolerance_change
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb)
......@@ -607,7 +607,7 @@ class Resample(torch.nn.Module):
return waveform
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
class ComplexNorm(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment