Commit b8203182 (unverified), authored Oct 21, 2019 by Vincent QB, committed via GitHub on Oct 21, 2019

standardizing n_* (#298)

Parent: ce1f8aaf
Changes: 3 changed files, with 46 additions and 47 deletions (+46 −47)
- README.md (+3 −3)
- torchaudio/functional.py (+40 −41)
- torchaudio/transforms.py (+3 −3)
README.md
````diff
@@ -85,8 +85,8 @@ Quick Usage
 ```python
 import torchaudio
-sound, sample_rate = torchaudio.load('foo.mp3')         # load tensor from file
-torchaudio.save('foo_save.mp3', sound, sample_rate)     # saves tensor to file
+waveform, sample_rate = torchaudio.load('foo.mp3')      # load tensor from file
+torchaudio.save('foo_save.mp3', waveform, sample_rate)  # save tensor to file
 ```

 API Reference
````
```diff
@@ -118,7 +118,7 @@ dimension (channel, time)")
 * `win_length`: the length of the STFT window
 * `window_fn`: for functions that creates windows e.g. `torch.hann_window`

-Transforms expect the following dimensions.
+Transforms expect and return the following dimensions.

 * `Spectrogram`: (channel, time) -> (channel, freq, time)
 * `AmplitudeToDB`: (channel, freq, time) -> (channel, freq, time)
```
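To keep the renamed dimension convention concrete, here is a minimal sketch of the documented shapes, assuming the transforms' default parameters (`n_fft=400`, so `freq == 201`); the input sizes are illustrative:

```python
import torch
import torchaudio

# A fake stereo waveform in the documented (channel, time) layout.
waveform = torch.randn(2, 16000)

# Spectrogram: (channel, time) -> (channel, freq, time).
spec = torchaudio.transforms.Spectrogram()(waveform)

# AmplitudeToDB: (channel, freq, time) -> (channel, freq, time), power values in decibels.
spec_db = torchaudio.transforms.AmplitudeToDB()(spec)

print(spec.shape, spec_db.shape)  # freq = n_fft // 2 + 1 = 201 with the default n_fft
```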
torchaudio/functional.py
```diff
@@ -83,7 +83,7 @@ def istft(
     Example: Suppose the last window is:
     [17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]

-    The n_frames, hop_length, win_length are all the same which prevents the calculation of right padding.
+    The n_frame, hop_length, win_length are all the same which prevents the calculation of right padding.
     These additional values could be zeros or a reflection of the signal so providing ``length``
     could be useful. If ``length`` is ``None`` then padding will be aggressively removed
     (some loss of signal).
```
```diff
@@ -93,8 +93,8 @@ def istft(
     Args:
         stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
-            column is a window. it has a size of either (channel, fft_size, n_frames, 2) or (
-            fft_size, n_frames, 2)
+            column is a window. it has a size of either (channel, fft_size, n_frame, 2) or (
+            fft_size, n_frame, 2)
         n_fft (int): Size of Fourier transform
         hop_length (Optional[int]): The distance between neighboring sliding window frames.
             (Default: ``win_length // 4``)
```
```diff
@@ -156,17 +156,17 @@ def istft(
     assert window.size(0) == n_fft
     # win_length and n_fft are synonymous from here on

-    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
+    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frame, fft_size, 2)
     stft_matrix = torch.irfft(
         stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
-    )  # size (channel, n_frames, n_fft)
+    )  # size (channel, n_frame, n_fft)

     assert stft_matrix.size(2) == n_fft
-    n_frames = stft_matrix.size(1)
+    n_frame = stft_matrix.size(1)

-    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frames, n_fft)
+    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frame, n_fft)
     # each column of a channel is a frame which needs to be overlap added at the right place
-    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)
+    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frame)

     eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
         1
```
```diff
@@ -180,13 +180,13 @@ def istft(
     # do the same for the window function
     window_sq = (
-        window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)
-    )  # size (1, n_fft, n_frames)
+        window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0)
+    )  # size (1, n_fft, n_frame)
     window_envelop = torch.nn.functional.conv_transpose1d(
         window_sq, eye, stride=hop_length, padding=0
     )  # size (1, 1, expected_signal_len)

-    expected_signal_len = n_fft + hop_length * (n_frames - 1)
+    expected_signal_len = n_fft + hop_length * (n_frame - 1)
     assert y.size(2) == expected_signal_len
     assert window_envelop.size(2) == expected_signal_len
```
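The `length` behaviour described in the docstring can be checked with a hedged round-trip sketch; it assumes the era-appropriate `torch.stft` that returns a real tensor with a trailing dimension of 2, and the parameter values are only illustrative:

```python
import torch
import torchaudio.functional as F

n_fft, hop_length = 400, 100
window = torch.hann_window(n_fft)
waveform = torch.randn(1, 4321)  # deliberately not a multiple of hop_length

# (channel, fft_size, n_frame, 2), matching the stft_matrix layout documented above.
stft_matrix = torch.stft(waveform, n_fft, hop_length=hop_length, window=window)

# Without ``length`` the trailing padding is removed aggressively; passing the original
# number of samples recovers a tensor of exactly the input size.
reconstructed = F.istft(stft_matrix, n_fft, hop_length=hop_length, window=window,
                        length=waveform.size(1))
assert reconstructed.shape == waveform.shape
```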
```diff
@@ -233,9 +233,9 @@ def spectrogram(
         normalized (bool): Whether to normalize by magnitude after stft

     Returns:
-        torch.Tensor: Dimension (channel, n_freq, time), where channel
-        is unchanged, n_freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
-        Fourier bins, and time is the number of window hops (n_frames).
+        torch.Tensor: Dimension (channel, freq, time), where channel
+        is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of
+        Fourier bins, and time is the number of window hops (n_frame).
     """

     assert waveform.dim() == 2
```
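As a quick check of the renamed shape description, a hedged sketch of the functional call; the positional argument order mirrors the `F.spectrogram` call in `transforms.py` further down, and the sizes are illustrative:

```python
import torch
import torchaudio.functional as F

waveform = torch.randn(1, 16000)                 # (channel, time)
n_fft, hop_length, win_length = 400, 200, 400
window = torch.hann_window(win_length)

# pad=0, power=2, normalized=False.
spec = F.spectrogram(waveform, 0, window, n_fft, hop_length, win_length, 2, False)

# (channel, freq, time): freq is n_fft // 2 + 1 = 201, time is the number of hops (n_frame).
print(spec.shape)
```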
```diff
@@ -537,7 +537,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
     Performs an IIR filter by evaluating difference equation.

     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`. Must be normalized to -1 to 1.
+        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1.
         a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
                                  Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
                                  Must be same size as b_coeffs (pad with 0's as necessary).
```
```diff
@@ -546,8 +546,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
                                  Must be same size as a_coeffs (pad with 0's as necessary).

     Returns:
-        output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`. Output will be clipped to -1 to 1.
-        Will be on the same device as the inputs.
+        output_waveform (torch.Tensor): Dimension of `(channel, time)`. Output will be clipped to -1 to 1.
     """
```
```diff
@@ -558,37 +557,37 @@ def lfilter(waveform, a_coeffs, b_coeffs):
     device = waveform.device
     dtype = waveform.dtype
-    n_channels, n_frames = waveform.size()
+    n_channel, n_sample = waveform.size()
     n_order = a_coeffs.size(0)
     assert (n_order > 0)

     # Pad the input and create output
-    padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1, dtype=dtype, device=device)
+    padded_waveform = torch.zeros(n_channel, n_sample + n_order - 1, dtype=dtype, device=device)
     padded_waveform[:, (n_order - 1):] = waveform
-    padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1, dtype=dtype, device=device)
+    padded_output_waveform = torch.zeros(n_channel, n_sample + n_order - 1, dtype=dtype, device=device)

     # Set up the coefficients matrix
     # Flip order, repeat, and transpose
-    a_coeffs_filled = a_coeffs.flip(0).repeat(n_channels, 1).t()
-    b_coeffs_filled = b_coeffs.flip(0).repeat(n_channels, 1).t()
+    a_coeffs_filled = a_coeffs.flip(0).repeat(n_channel, 1).t()
+    b_coeffs_filled = b_coeffs.flip(0).repeat(n_channel, 1).t()

     # Set up a few other utilities
-    a0_repeated = torch.ones(n_channels, dtype=dtype, device=device) * a_coeffs[0]
-    ones = torch.ones(n_channels, n_frames, dtype=dtype, device=device)
+    a0_repeated = torch.ones(n_channel, dtype=dtype, device=device) * a_coeffs[0]
+    ones = torch.ones(n_channel, n_sample, dtype=dtype, device=device)

-    for i_frame in range(n_frames):
-        o0 = torch.zeros(n_channels, dtype=dtype, device=device)
+    for i_sample in range(n_sample):
+        o0 = torch.zeros(n_channel, dtype=dtype, device=device)

-        windowed_input_signal = padded_waveform[:, i_frame:(i_frame + n_order)]
-        windowed_output_signal = padded_output_waveform[:, i_frame:(i_frame + n_order)]
+        windowed_input_signal = padded_waveform[:, i_sample:(i_sample + n_order)]
+        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]

         o0.add_(torch.diag(torch.mm(windowed_input_signal, b_coeffs_filled)))
         o0.sub_(torch.diag(torch.mm(windowed_output_signal, a_coeffs_filled)))

         o0.div_(a0_repeated)

-        padded_output_waveform[:, i_frame + n_order - 1] = o0
+        padded_output_waveform[:, i_sample + n_order - 1] = o0

     return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]))
```
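To illustrate the difference-equation convention spelled out in the docstring, a small hedged example; the coefficients implement a one-pole smoother, `y[n] = 0.25 * x[n] + 0.75 * y[n-1]`, and are purely illustrative:

```python
import torch
import torchaudio.functional as F

waveform = torch.rand(1, 1000) * 2 - 1        # (channel, time), already in [-1, 1]

# Written as a0*y[n] + a1*y[n-1] = b0*x[n] + b1*x[n-1], lower delays first.
a_coeffs = torch.tensor([1.0, -0.75])         # [a0, a1]
b_coeffs = torch.tensor([0.25, 0.0])          # [b0, b1], zero-padded to match a_coeffs

filtered = F.lfilter(waveform, a_coeffs, b_coeffs)
print(filtered.shape)                         # (channel, time), clipped to [-1, 1]
```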
```diff
@@ -599,7 +598,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
     https://en.wikipedia.org/wiki/Digital_biquad_filter

     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
+        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
         b0 (float): numerator coefficient of current input, x[n]
         b1 (float): numerator coefficient of input one time step ago x[n-1]
         b2 (float): numerator coefficient of input two time steps ago x[n-2]
```
@@ -608,7 +607,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
...
@@ -608,7 +607,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
a2 (float): denominator coefficient of current output y[n-2]
a2 (float): denominator coefficient of current output y[n-2]
Returns:
Returns:
output_waveform (torch.Tensor): Dimension of `(
n_
channel,
n_fra
me
s
)`
output_waveform (torch.Tensor): Dimension of `(channel,
ti
me)`
"""
"""
device
=
waveform
.
device
device
=
waveform
.
device
...
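A hedged sketch relating `biquad` to the general `lfilter` above; the coefficient values are arbitrary illustrations, not a designed filter:

```python
import torch
import torchaudio.functional as F

waveform = torch.rand(1, 1000) * 2 - 1

# biquad is the second-order special case of lfilter: a = [a0, a1, a2], b = [b0, b1, b2].
out_biquad = F.biquad(waveform, b0=0.2, b1=0.2, b2=0.2, a0=1.0, a1=-0.5, a2=0.1)
out_lfilter = F.lfilter(waveform,
                        torch.tensor([1.0, -0.5, 0.1]),
                        torch.tensor([0.2, 0.2, 0.2]))

print(torch.allclose(out_biquad, out_lfilter))  # expected to be True
```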
```diff
@@ -631,13 +630,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
     r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation.

     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
+        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
         sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
         cutoff_freq (float): filter cutoff frequency
         Q (float): https://en.wikipedia.org/wiki/Q_factor

     Returns:
-        output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
+        output_waveform (torch.Tensor): Dimension of `(channel, time)`
     """
     GAIN = 1  # TBD - add as a parameter
```
```diff
@@ -660,13 +659,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
     r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation.

     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
+        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
         sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
         cutoff_freq (float): filter cutoff frequency
         Q (float): https://en.wikipedia.org/wiki/Q_factor

     Returns:
-        output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
+        output_waveform (torch.Tensor): Dimension of `(channel, time)`
     """
     GAIN = 1
```
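Both wrappers share the same interface, so a brief hedged usage sketch covers them together; the sample rate, cutoff, and signal length are placeholder values:

```python
import torch
import torchaudio.functional as F

sample_rate = 44100
waveform = torch.rand(2, 4000) * 2 - 1     # (channel, time), short noise burst

lowpassed = F.lowpass_biquad(waveform, sample_rate, cutoff_freq=1000.0)
highpassed = F.highpass_biquad(waveform, sample_rate, cutoff_freq=1000.0, Q=0.707)

print(lowpassed.shape, highpassed.shape)   # both keep the (channel, time) shape
```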
```diff
@@ -693,13 +692,13 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
     All examples will have the same mask interval.

     Args:
-        specgrams (Tensor): Real spectrograms (batch, channel, n_freq, time)
+        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
         axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

     Returns:
-        torch.Tensor: Masked spectrograms of dimensions (batch, channel, n_freq, time)
+        torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
     """

     if axis != 2 and axis != 3:
```
```diff
@@ -730,13 +729,13 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
     All examples will have the same mask interval.

     Args:
-        specgram (Tensor): Real spectrogram (channel, n_freq, time)
+        specgram (Tensor): Real spectrogram (channel, freq, time)
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
         axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

     Returns:
-        torch.Tensor: Masked spectrogram of dimensions (channel, n_freq, time)
+        torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
     """
     value = torch.rand(1) * mask_param
```
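A hedged sketch of the two masking helpers on random spectrogram-shaped tensors; the shapes and `mask_param` are illustrative:

```python
import torch
import torchaudio.functional as F

specgram = torch.randn(2, 201, 400)        # (channel, freq, time)
batch = torch.randn(4, 2, 201, 400)        # (batch, channel, freq, time)

# Mask up to 30 consecutive time columns (axis 2 of the unbatched layout).
time_masked = F.mask_along_axis(specgram, mask_param=30, mask_value=0.0, axis=2)

# Same idea per example, along the frequency axis (axis 2 of the batched layout).
freq_masked = F.mask_along_axis_iid(batch, mask_param=30, mask_value=0.0, axis=2)

print(time_masked.shape, freq_masked.shape)  # shapes are unchanged
```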
@@ -769,12 +768,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
...
@@ -769,12 +768,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
:math:`N` is (`win_length`-1)//2.
:math:`N` is (`win_length`-1)//2.
Args:
Args:
specgram (torch.Tensor): Tensor of audio of dimension (channel,
n_mfcc
, time)
specgram (torch.Tensor): Tensor of audio of dimension (channel,
freq
, time)
win_length (int): The window length used for computing delta
win_length (int): The window length used for computing delta
mode (str): Mode parameter passed to padding
mode (str): Mode parameter passed to padding
Returns:
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (channel,
n_mfcc
, time)
deltas (torch.Tensor): Tensor of audio of dimension (channel,
freq
, time)
Example
Example
>>> specgram = torch.randn(1, 40, 1000)
>>> specgram = torch.randn(1, 40, 1000)
...
...
torchaudio/transforms.py
```diff
@@ -62,7 +62,7 @@ class Spectrogram(torch.jit.ScriptModule):
         Returns:
             torch.Tensor: Dimension (channel, freq, time), where channel
             is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
-            Fourier bins, and time is the number of window hops (n_frames).
+            Fourier bins, and time is the number of window hops (n_frame).
         """
         return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                              self.win_length, self.power, self.normalized)
```
```diff
@@ -409,9 +409,9 @@ class ComputeDeltas(torch.jit.ScriptModule):
     def forward(self, specgram):
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
+            specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time)

         Returns:
-            deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
+            deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
         """
         return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
```
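For completeness, a hedged usage sketch of the transform wrapper around `compute_deltas`; the 40-bin spectrogram mirrors the docstring example earlier in this commit, and `win_length=5` is assumed to be the default:

```python
import torch
import torchaudio

specgram = torch.randn(1, 40, 1000)        # (channel, freq, time)

deltas = torchaudio.transforms.ComputeDeltas(win_length=5)(specgram)
print(deltas.shape)                        # (channel, freq, time), same as the input
```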