Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
08a71271
Unverified
Commit
08a71271
authored
Aug 03, 2020
by
gmagogsfm
Committed by
GitHub
Aug 03, 2020
Browse files
Switch string formatting to str.format to be TorchScript friendly. (#850)
parent
3bab2b29
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
14 deletions
+18
-14
test/compliance/generate_fbank_data.py
test/compliance/generate_fbank_data.py
+1
-1
test/test_compliance_kaldi.py
test/test_compliance_kaldi.py
+3
-1
torchaudio/compliance/kaldi.py
torchaudio/compliance/kaldi.py
+7
-5
torchaudio/functional.py
torchaudio/functional.py
+2
-2
torchaudio/transforms.py
torchaudio/transforms.py
+5
-5
No files found.
test/compliance/generate_fbank_data.py
View file @
08a71271
...
...
@@ -92,7 +92,7 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir):
'round_to_power_of_two'
,
'snip_edges'
,
'subtract_mean'
,
'use_energy'
,
'use_log_fbank'
,
'use_power'
,
'vtln_high'
,
'vtln_low'
,
'vtln_warp'
,
'window_type'
]
fn_split
=
fn
.
split
(
'-'
)
assert
len
(
fn_split
)
==
len
(
arr
),
(
'Len mismatch:
%d
and
%d'
%
(
len
(
fn_split
),
len
(
arr
)))
assert
len
(
fn_split
)
==
len
(
arr
),
(
'Len mismatch:
{}
and
{}'
.
format
(
len
(
fn_split
),
len
(
arr
)))
inputs
=
{
arr
[
i
]:
utils
.
parse
(
fn_split
[
i
])
for
i
in
range
(
len
(
arr
))}
# print flags for C++
...
...
test/test_compliance_kaldi.py
View file @
08a71271
...
...
@@ -148,7 +148,9 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
sound
,
sr
=
torchaudio
.
load_wav
(
sound_filepath
)
files
=
self
.
test_filepaths
[
filepath_key
]
assert
len
(
files
)
==
expected_num_files
,
(
'number of kaldi %s file changed to %d'
%
(
filepath_key
,
len
(
files
)))
assert
len
(
files
)
==
expected_num_files
,
\
(
'number of kaldi {} file changed to {}'
.
format
(
filepath_key
,
len
(
files
)))
for
f
in
files
:
print
(
f
)
...
...
torchaudio/compliance/kaldi.py
View file @
08a71271
...
...
@@ -135,13 +135,15 @@ def _get_waveform_and_window_properties(waveform: Tensor,
r
"""Gets the waveform and window properties
"""
channel
=
max
(
channel
,
0
)
assert
channel
<
waveform
.
size
(
0
),
(
'Invalid channel
%d
for size
%d'
%
(
channel
,
waveform
.
size
(
0
)))
assert
channel
<
waveform
.
size
(
0
),
(
'Invalid channel
{}
for size
{}'
.
format
(
channel
,
waveform
.
size
(
0
)))
waveform
=
waveform
[
channel
,
:]
# size (n)
window_shift
=
int
(
sample_frequency
*
frame_shift
*
MILLISECONDS_TO_SECONDS
)
window_size
=
int
(
sample_frequency
*
frame_length
*
MILLISECONDS_TO_SECONDS
)
padded_window_size
=
_next_power_of_2
(
window_size
)
if
round_to_power_of_two
else
window_size
assert
2
<=
window_size
<=
len
(
waveform
),
(
'choose a window size %d that is [2, %d]'
%
(
window_size
,
len
(
waveform
)))
assert
2
<=
window_size
<=
len
(
waveform
),
(
'choose a window size {} that is [2, {}]'
.
format
(
window_size
,
len
(
waveform
)))
assert
0
<
window_shift
,
'`window_shift` must be greater than 0'
assert
padded_window_size
%
2
==
0
,
'the padded `window_size` must be divisible by two.'
\
' use `round_to_power_of_two` or change `frame_length`'
...
...
@@ -430,7 +432,7 @@ def get_mel_banks(num_bins: int,
high_freq
+=
nyquist
assert
(
0.0
<=
low_freq
<
nyquist
)
and
(
0.0
<
high_freq
<=
nyquist
)
and
(
low_freq
<
high_freq
),
\
(
'Bad values in options: low-freq
%f
and high-freq
%f
vs. nyquist
%f'
%
(
low_freq
,
high_freq
,
nyquist
))
(
'Bad values in options: low-freq
{}
and high-freq
{}
vs. nyquist
{}'
.
format
(
low_freq
,
high_freq
,
nyquist
))
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width
=
sample_freq
/
window_length_padded
...
...
@@ -446,8 +448,8 @@ def get_mel_banks(num_bins: int,
assert
vtln_warp_factor
==
1.0
or
((
low_freq
<
vtln_low
<
high_freq
)
and
(
0.0
<
vtln_high
<
high_freq
)
and
(
vtln_low
<
vtln_high
)),
\
(
'Bad values in options: vtln-low
%f
and vtln-high
%f
, versus
low-freq %f and high-freq %f'
%
(
vtln_low
,
vtln_high
,
low_freq
,
high_freq
))
(
'Bad values in options: vtln-low
{}
and vtln-high
{}
, versus
'
'low-freq {} and high-freq {}'
.
format
(
vtln_low
,
vtln_high
,
low_freq
,
high_freq
))
bin
=
torch
.
arange
(
num_bins
).
unsqueeze
(
1
)
left_mel
=
mel_low_freq
+
bin
*
mel_freq_delta
# size(num_bins, 1)
...
...
torchaudio/functional.py
View file @
08a71271
...
...
@@ -149,8 +149,8 @@ def griffinlim(
Returns:
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
"""
assert
momentum
<
1
,
'momentum=
%s
> 1 can be unstable'
%
momentum
assert
momentum
>=
0
,
'momentum=
%s
< 0'
%
momentum
assert
momentum
<
1
,
'momentum=
{}
> 1 can be unstable'
.
format
(
momentum
)
assert
momentum
>=
0
,
'momentum=
{}
< 0'
.
format
(
momentum
)
# pack batch
shape
=
specgram
.
size
()
...
...
torchaudio/transforms.py
View file @
08a71271
...
...
@@ -141,8 +141,8 @@ class GriffinLim(torch.nn.Module):
rand_init
:
bool
=
True
)
->
None
:
super
(
GriffinLim
,
self
).
__init__
()
assert
momentum
<
1
,
'momentum=
%s
> 1 can be unstable'
%
momentum
assert
momentum
>
0
,
'momentum=
%s
< 0'
%
momentum
assert
momentum
<
1
,
'momentum=
{}
> 1 can be unstable'
.
format
(
momentum
)
assert
momentum
>
0
,
'momentum=
{}
< 0'
.
format
(
momentum
)
self
.
n_fft
=
n_fft
self
.
n_iter
=
n_iter
...
...
@@ -237,7 +237,7 @@ class MelScale(torch.nn.Module):
self
.
f_max
=
f_max
if
f_max
is
not
None
else
float
(
sample_rate
//
2
)
self
.
f_min
=
f_min
assert
f_min
<=
self
.
f_max
,
'Require f_min:
%f
< f_max:
%f'
%
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min:
{}
< f_max:
{}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
torch
.
empty
(
0
)
if
n_stft
is
None
else
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
...
...
@@ -313,7 +313,7 @@ class InverseMelScale(torch.nn.Module):
self
.
tolerance_change
=
tolerance_change
self
.
sgdargs
=
sgdargs
or
{
'lr'
:
0.1
,
'momentum'
:
0.9
}
assert
f_min
<=
self
.
f_max
,
'Require f_min:
%f
< f_max:
%f'
%
(
f_min
,
self
.
f_max
)
assert
f_min
<=
self
.
f_max
,
'Require f_min:
{}
< f_max:
{}'
.
format
(
f_min
,
self
.
f_max
)
fb
=
F
.
create_fb_matrix
(
n_stft
,
self
.
f_min
,
self
.
f_max
,
self
.
n_mels
,
self
.
sample_rate
)
self
.
register_buffer
(
'fb'
,
fb
)
...
...
@@ -607,7 +607,7 @@ class Resample(torch.nn.Module):
return
waveform
raise
ValueError
(
'Invalid resampling method:
%s'
%
(
self
.
resampling_method
))
raise
ValueError
(
'Invalid resampling method:
{}'
.
format
(
self
.
resampling_method
))
class
ComplexNorm
(
torch
.
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment