Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ffeba11a
"vscode:/vscode.git/clone" did not exist on "668ecc6c5b375249578b83dbabfb47c3ed5d9dbd"
Commit
ffeba11a
authored
Sep 02, 2024
by
mayp777
Browse files
UPDATE
parent
29deb085
Changes
337
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1847 additions
and
137 deletions
+1847
-137
test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py
...io_unittest/functional/librosa_compatibility_cuda_test.py
+3
-1
test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py
...io_unittest/functional/librosa_compatibility_test_impl.py
+2
-2
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
..._unittest/functional/torchscript_consistency_cuda_test.py
+2
-1
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+52
-14
test/torchaudio_unittest/io/common.py
test/torchaudio_unittest/io/common.py
+16
-0
test/torchaudio_unittest/io/effector_test.py
test/torchaudio_unittest/io/effector_test.py
+102
-0
test/torchaudio_unittest/io/playback_test.py
test/torchaudio_unittest/io/playback_test.py
+65
-0
test/torchaudio_unittest/io/stream_reader_test.py
test/torchaudio_unittest/io/stream_reader_test.py
+645
-33
test/torchaudio_unittest/io/stream_writer_test.py
test/torchaudio_unittest/io/stream_writer_test.py
+449
-42
test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
+16
-0
test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py
...rchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py
+49
-0
test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
...io_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
+2
-2
test/torchaudio_unittest/models/squim/__init__.py
test/torchaudio_unittest/models/squim/__init__.py
+0
-0
test/torchaudio_unittest/models/squim/squim_test.py
test/torchaudio_unittest/models/squim/squim_test.py
+113
-0
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
+3
-3
test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py
...udio_unittest/models/wav2vec2/fairseq_integration_test.py
+44
-6
test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
...unittest/models/wav2vec2/huggingface_intergration_test.py
+143
-10
test/torchaudio_unittest/models/wav2vec2/model_test.py
test/torchaudio_unittest/models/wav2vec2/model_test.py
+135
-0
test/torchaudio_unittest/sox_effect/dataset_test.py
test/torchaudio_unittest/sox_effect/dataset_test.py
+6
-2
test/torchaudio_unittest/sox_effect/smoke_test.py
test/torchaudio_unittest/sox_effect/smoke_test.py
+0
-21
No files found.
Too many changes to show.
To preserve performance only
337 of 337+
files are displayed.
Plain diff
Email patch
test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py
View file @
ffeba11a
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
,
skipIfkmeMark
from
.librosa_compatibility_test_impl
import
Functional
,
FunctionalComplex
@
skipIfNoCuda
@
skipIfkmeMark
class
TestFunctionalCUDA
(
Functional
,
PytorchTestCase
):
device
=
"cuda"
@
skipIfNoCuda
@
skipIfkmeMark
class
TestFunctionalComplexCUDA
(
FunctionalComplex
,
PytorchTestCase
):
device
=
"cuda"
test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py
View file @
ffeba11a
import
unittest
from
distutils.version
import
Strict
Version
from
distutils.version
import
Loose
Version
import
torch
import
torchaudio.functional
as
F
...
...
@@ -77,7 +77,7 @@ class Functional(TestBaseMixin):
def
test_create_mel_fb
(
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
,
mel_scale
=
"htk"
):
if
norm
==
"slaney"
and
Strict
Version
(
librosa
.
__version__
)
<
Strict
Version
(
"0.7.2"
):
if
norm
==
"slaney"
and
Loose
Version
(
librosa
.
__version__
)
<
Loose
Version
(
"0.7.2"
):
self
.
skipTest
(
"Test is known to fail with older versions of librosa."
)
if
self
.
device
!=
"cpu"
:
self
.
skipTest
(
"No need to run this test on CUDA"
)
...
...
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
View file @
ffeba11a
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
,
skipIfkmeMark
from
.torchscript_consistency_impl
import
Functional
,
FunctionalFloat32Only
...
...
@@ -11,6 +11,7 @@ class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
@
skipIfNoCuda
@
skipIfkmeMark
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
"cuda"
)
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
ffeba11a
...
...
@@ -585,22 +585,10 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
(
tensor
,))
@
common_utils
.
skipIfNoKaldi
def
test_compute_kaldi_pitch
(
self
):
if
self
.
dtype
!=
torch
.
float32
or
self
.
device
!=
torch
.
device
(
"cpu"
):
raise
unittest
.
SkipTest
(
"Only float32, cpu is supported."
)
def
func
(
tensor
):
sample_rate
:
float
=
44100.0
return
F
.
compute_kaldi_pitch
(
tensor
,
sample_rate
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
(
tensor
,))
def
test_resample_sinc
(
self
):
def
func
(
tensor
):
sr1
,
sr2
=
16000
,
8000
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"sinc_interp
olatio
n"
)
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"sinc_interp
_han
n"
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
16000
)
self
.
_assert_consistency
(
func
,
(
tensor
,))
...
...
@@ -616,7 +604,9 @@ class Functional(TempDirMixin, TestBaseMixin):
sr1
,
sr2
=
16000
,
8000
lowpass_filter_width
=
6
rolloff
=
0.99
self
.
_assert_consistency
(
F
.
resample
,
(
tensor
,
sr1
,
sr2
,
lowpass_filter_width
,
rolloff
,
"kaiser_window"
,
beta
))
self
.
_assert_consistency
(
F
.
resample
,
(
tensor
,
sr1
,
sr2
,
lowpass_filter_width
,
rolloff
,
"sinc_interp_kaiser"
,
beta
)
)
def
test_phase_vocoder
(
self
):
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
...
...
@@ -756,6 +746,54 @@ class Functional(TempDirMixin, TestBaseMixin):
specgram
=
torch
.
rand
(
num_channels
,
n_fft_bin
,
num_frames
,
dtype
=
self
.
complex_dtype
,
device
=
self
.
device
)
self
.
_assert_consistency_complex
(
F
.
apply_beamforming
,
(
beamform_weights
,
specgram
))
@
common_utils
.
nested_params
(
[
"convolve"
,
"fftconvolve"
],
[
"full"
,
"valid"
,
"same"
],
)
def
test_convolve
(
self
,
fn
,
mode
):
leading_dims
=
(
2
,
3
,
2
)
L_x
,
L_y
=
32
,
55
x
=
torch
.
rand
(
*
leading_dims
,
L_x
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
y
=
torch
.
rand
(
*
leading_dims
,
L_y
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
self
.
_assert_consistency
(
getattr
(
F
,
fn
),
(
x
,
y
,
mode
))
@
common_utils
.
nested_params
([
True
,
False
])
def
test_add_noise
(
self
,
use_lengths
):
leading_dims
=
(
2
,
3
)
L
=
31
waveform
=
torch
.
rand
(
*
leading_dims
,
L
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
noise
=
torch
.
rand
(
*
leading_dims
,
L
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
if
use_lengths
:
lengths
=
torch
.
rand
(
*
leading_dims
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
else
:
lengths
=
None
snr
=
torch
.
rand
(
*
leading_dims
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
*
10
self
.
_assert_consistency
(
F
.
add_noise
,
(
waveform
,
noise
,
snr
,
lengths
))
@
common_utils
.
nested_params
([
True
,
False
])
def
test_speed
(
self
,
use_lengths
):
leading_dims
=
(
3
,
2
)
T
=
200
waveform
=
torch
.
rand
(
*
leading_dims
,
T
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
if
use_lengths
:
lengths
=
torch
.
randint
(
1
,
T
,
leading_dims
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
else
:
lengths
=
None
self
.
_assert_consistency
(
F
.
speed
,
(
waveform
,
1000
,
1.1
,
lengths
))
def
test_preemphasis
(
self
):
waveform
=
torch
.
rand
(
3
,
2
,
100
,
device
=
self
.
device
,
dtype
=
self
.
dtype
)
coeff
=
0.9
self
.
_assert_consistency
(
F
.
preemphasis
,
(
waveform
,
coeff
))
def
test_deemphasis
(
self
):
waveform
=
torch
.
rand
(
3
,
2
,
100
,
device
=
self
.
device
,
dtype
=
self
.
dtype
)
coeff
=
0.9
self
.
_assert_consistency
(
F
.
deemphasis
,
(
waveform
,
coeff
))
class
FunctionalFloat32Only
(
TestBaseMixin
):
def
test_rnnt_loss
(
self
):
...
...
test/torchaudio_unittest/io/common.py
0 → 100644
View file @
ffeba11a
import
torchaudio
# If FFmpeg is 4.1 or older
# Tests that checks the number of output samples from OPUS fails
# They work on 4.2+
# Probably this commit fixed it.
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
def
lt42
():
ver
=
torchaudio
.
utils
.
ffmpeg_utils
.
get_versions
()[
"libavcodec"
]
# 5.1 libavcodec 59. 18.100
# 4.4 libavcodec 58.134.100
# 4.3 libavcodec 58. 91.100
# 4.2 libavcodec 58. 54.100
# 4.1 libavcodec 58. 35.100
return
ver
[
0
]
<
59
and
ver
[
1
]
<
54
test/torchaudio_unittest/io/effector_test.py
0 → 100644
View file @
ffeba11a
from
parameterized
import
parameterized
from
torchaudio.io
import
AudioEffector
from
torchaudio_unittest.common_utils
import
get_sinusoid
,
skipIfNoFFmpeg
,
TorchaudioTestCase
from
.common
import
lt42
@
skipIfNoFFmpeg
class
EffectorTest
(
TorchaudioTestCase
):
def
test_null
(
self
):
"""No effect and codec will return the same result"""
sample_rate
=
8000
frames_per_chunk
=
256
effector
=
AudioEffector
(
effect
=
None
,
format
=
None
)
original
=
get_sinusoid
(
n_channels
=
3
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
# one-go
output
=
effector
.
apply
(
original
,
sample_rate
)
self
.
assertEqual
(
original
,
output
)
# streaming
for
i
,
chunk
in
enumerate
(
effector
.
stream
(
original
,
sample_rate
,
frames_per_chunk
)):
start
=
i
*
frames_per_chunk
end
=
(
i
+
1
)
*
frames_per_chunk
self
.
assertEqual
(
original
[
start
:
end
,
:],
chunk
)
@
parameterized
.
expand
(
[
(
"ogg"
,
"flac"
),
# flac only supports s16 and s32
(
"ogg"
,
"opus"
),
# opus only supports 48k Hz
(
"ogg"
,
"vorbis"
),
# vorbis only supports stereo
# ("ogg", "vorbis", 44100),
# this fails with small descrepancy; 441024 vs 441000
# TODO: investigate
(
"wav"
,
None
),
(
"wav"
,
"pcm_u8"
),
(
"mp3"
,
None
),
(
"mulaw"
,
None
,
44100
),
# mulaw is encoded without header
]
)
def
test_formats
(
self
,
format
,
encoder
,
sample_rate
=
8000
):
"""Formats (some with restrictions) just work without an issue in effector"""
effector
=
AudioEffector
(
format
=
format
,
encoder
=
encoder
)
original
=
get_sinusoid
(
n_channels
=
3
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
output
=
effector
.
apply
(
original
,
sample_rate
)
# On 4.1 OPUS produces 8020 samples (extra 20)
# this has been fixed on 4.2+
if
encoder
==
"opus"
and
lt42
():
return
self
.
assertEqual
(
original
.
shape
,
output
.
shape
)
# Note
# MP3 adds padding which cannot be removed when the encoded data is written to
# file-like object without seek method.
# The number of padding is retrievable as `AVCoedcContext::initial_padding`
# https://ffmpeg.org/doxygen/4.1/structAVCodecContext.html#a8f95550ce04f236e9915516d04d3d1ab
# but this is not exposed yet.
# These "priming" samples have negative time stamp, so we can also add logic
# to discard them at decoding, however, as far as I checked, when data is loaded
# with StreamReader, the time stamp is reset. I tried options like avoid_negative_ts,
# https://ffmpeg.org/ffmpeg-formats.html
# but it made no difference. Perhaps this is because the information about negative
# timestamp is only available at encoding side, and it presumably is written to
# header file, but it is not happening somehow with file-like object.
# Need to investigate more to remove MP3 padding
if
format
==
"mp3"
:
return
for
chunk
in
effector
.
stream
(
original
,
sample_rate
,
frames_per_chunk
=
original
.
size
(
0
)):
self
.
assertEqual
(
original
.
shape
,
chunk
.
shape
)
@
parameterized
.
expand
([(
"loudnorm=I=-16:LRA=11:TP=-1.5"
,),
(
"volume=2"
,)])
def
test_effect
(
self
,
effect
):
sample_rate
=
8000
effector
=
AudioEffector
(
effect
=
effect
)
original
=
get_sinusoid
(
n_channels
=
3
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
output
=
effector
.
apply
(
original
,
sample_rate
)
self
.
assertEqual
(
original
.
shape
,
output
.
shape
)
def
test_resample
(
self
):
"""Resample option allows to change the sampling rate"""
sample_rate
=
8000
output_sample_rate
=
16000
num_channels
=
3
effector
=
AudioEffector
(
effect
=
"lowpass"
)
original
=
get_sinusoid
(
n_channels
=
num_channels
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
output
=
effector
.
apply
(
original
,
sample_rate
,
output_sample_rate
)
self
.
assertEqual
(
output
.
shape
,
[
output_sample_rate
,
num_channels
])
for
chunk
in
effector
.
stream
(
original
,
sample_rate
,
output_sample_rate
=
output_sample_rate
,
frames_per_chunk
=
output_sample_rate
):
self
.
assertEqual
(
chunk
.
shape
,
[
output_sample_rate
,
num_channels
])
test/torchaudio_unittest/io/playback_test.py
0 → 100644
View file @
ffeba11a
from
unittest.mock
import
patch
import
torch
from
parameterized
import
parameterized
from
torchaudio.io
import
play_audio
,
StreamWriter
from
torchaudio_unittest.common_utils
import
get_sinusoid
,
skipIfNoAudioDevice
,
skipIfNoMacOS
,
TorchaudioTestCase
@
skipIfNoAudioDevice
@
skipIfNoMacOS
class
PlaybackInterfaceTest
(
TorchaudioTestCase
):
@
parameterized
.
expand
([(
"uint8"
,),
(
"int16"
,),
(
"int32"
,),
(
"int64"
,),
(
"float32"
,),
(
"float64"
,)])
@
patch
.
object
(
StreamWriter
,
"write_audio_chunk"
)
def
test_playaudio
(
self
,
dtype
,
writeaudio_mock
):
"""Test playaudio function.
The patch object is used to check if the data is written
to the output device stream, without playing the actual audio.
"""
dtype
=
getattr
(
torch
,
dtype
)
sample_rate
=
8000
waveform
=
get_sinusoid
(
frequency
=
440
,
sample_rate
=
sample_rate
,
duration
=
1
,
# seconds
n_channels
=
1
,
dtype
=
dtype
,
device
=
"cpu"
,
channels_first
=
False
,
)
play_audio
(
waveform
,
sample_rate
=
sample_rate
)
writeaudio_mock
.
assert_called
()
@
parameterized
.
expand
(
[
# Invalid number of dimensions (!= 2)
(
"int16"
,
1
,
"audiotoolbox"
),
(
"int16"
,
3
,
"audiotoolbox"
),
# Invalid tensor type
(
"complex64"
,
2
,
"audiotoolbox"
),
# Invalid output device
(
"int16"
,
2
,
"audiotool"
),
]
)
@
patch
.
object
(
StreamWriter
,
"write_audio_chunk"
)
def
test_playaudio_invalid_options
(
self
,
dtype
,
ndim
,
device
,
writeaudio_mock
):
"""Test playaudio function raises error with invalid options."""
dtype
=
getattr
(
torch
,
dtype
)
sample_rate
=
8000
waveform
=
get_sinusoid
(
frequency
=
440
,
sample_rate
=
sample_rate
,
duration
=
1
,
# seconds
n_channels
=
1
,
dtype
=
dtype
,
device
=
"cpu"
,
channels_first
=
False
,
).
squeeze
()
for
_
in
range
(
ndim
-
1
):
waveform
=
waveform
.
unsqueeze
(
-
1
)
with
self
.
assertRaises
(
ValueError
):
play_audio
(
waveform
,
sample_rate
=
sample_rate
,
device
=
device
)
test/torchaudio_unittest/io/stream_reader_test.py
View file @
ffeba11a
import
io
import
torch
import
torchaudio
from
parameterized
import
parameterized
,
parameterized_class
from
torchaudio_unittest.common_utils
import
(
disabledInCI
,
get_asset_path
,
get_image
,
get_sinusoid
,
get_wav_data
,
is_ffmpeg_available
,
nested_params
,
...
...
@@ -12,24 +16,68 @@ from torchaudio_unittest.common_utils import (
save_image
,
save_wav
,
skipIfNoFFmpeg
,
skipIfNoHWAccel
,
TempDirMixin
,
TorchaudioTestCase
,
)
if
is_ffmpeg_available
():
from
torchaudio.io
import
(
StreamReader
,
StreamReaderSourceAudioStream
,
StreamReaderSourceStream
,
StreamReaderSourceVideoStream
,
from
torchaudio.io
import
StreamReader
,
StreamWriter
from
torchaudio.io._stream_reader
import
(
ChunkTensor
,
OutputAudioStream
,
OutputVideoStream
,
SourceAudioStream
,
SourceStream
,
SourceVideoStream
,
)
@
skipIfNoFFmpeg
class
ChunkTensorTest
(
TorchaudioTestCase
):
def
test_chunktensor
(
self
):
"""ChunkTensor serves as a replacement of tensor"""
data
=
torch
.
randn
((
256
,
2
))
pts
=
16.0
c
=
ChunkTensor
(
data
,
pts
)
assert
c
.
pts
==
pts
self
.
assertEqual
(
c
,
data
)
# method
sum_
=
c
.
sum
()
assert
isinstance
(
sum_
,
torch
.
Tensor
)
self
.
assertEqual
(
sum_
,
data
.
sum
())
# function form
min_
=
torch
.
min
(
c
)
assert
isinstance
(
min_
,
torch
.
Tensor
)
self
.
assertEqual
(
min_
,
torch
.
min
(
data
))
# attribute
t
=
c
.
T
assert
isinstance
(
t
,
torch
.
Tensor
)
self
.
assertEqual
(
t
,
data
.
T
)
# in-place op
c
[
0
]
=
0
self
.
assertEqual
(
c
,
data
)
# pass to other C++ code
buffer
=
io
.
BytesIO
()
w
=
StreamWriter
(
buffer
,
format
=
"wav"
)
w
.
add_audio_stream
(
8000
,
2
)
with
w
.
open
():
w
.
write_audio_chunk
(
0
,
c
)
w
.
write_audio_chunk
(
0
,
c
,
c
.
pts
)
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source
=
parameterized_class
(
(
"test_type"
,),
[(
"str"
,),
(
"fileobj"
,)
,
(
"tensor"
,)
],
[(
"str"
,),
(
"fileobj"
,)],
class_name_func
=
lambda
cls
,
_
,
params
:
f
'
{
cls
.
__name__
}
_
{
params
[
"test_type"
]
}
'
,
)
...
...
@@ -47,13 +95,6 @@ class _MediaSourceMixin:
self
.
src
=
path
elif
self
.
test_type
==
"fileobj"
:
self
.
src
=
open
(
path
,
"rb"
)
elif
self
.
test_type
==
"tensor"
:
with
open
(
path
,
"rb"
)
as
fileobj
:
data
=
fileobj
.
read
()
self
.
src
=
torch
.
frombuffer
(
data
,
dtype
=
torch
.
uint8
)
print
(
self
.
src
.
data_ptr
())
print
(
len
(
data
))
print
(
self
.
src
.
shape
)
return
self
.
src
def
tearDown
(
self
):
...
...
@@ -112,7 +153,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
base_metadata
=
{}
expected
=
[
StreamReader
SourceVideoStream
(
SourceVideoStream
(
media_type
=
"video"
,
codec
=
"h264"
,
codec_long_name
=
"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"
,
...
...
@@ -129,7 +170,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height
=
180
,
frame_rate
=
25.0
,
),
StreamReader
SourceAudioStream
(
SourceAudioStream
(
media_type
=
"audio"
,
codec
=
"aac"
,
codec_long_name
=
"AAC (Advanced Audio Coding)"
,
...
...
@@ -145,7 +186,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate
=
8000.0
,
num_channels
=
2
,
),
StreamReader
SourceStream
(
SourceStream
(
media_type
=
"subtitle"
,
codec
=
"mov_text"
,
codec_long_name
=
"MOV text"
,
...
...
@@ -158,7 +199,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
"language"
:
"eng"
,
},
),
StreamReader
SourceVideoStream
(
SourceVideoStream
(
media_type
=
"video"
,
codec
=
"h264"
,
codec_long_name
=
"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"
,
...
...
@@ -175,7 +216,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height
=
270
,
frame_rate
=
29.97002997002997
,
),
StreamReader
SourceAudioStream
(
SourceAudioStream
(
media_type
=
"audio"
,
codec
=
"aac"
,
codec_long_name
=
"AAC (Advanced Audio Coding)"
,
...
...
@@ -191,7 +232,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate
=
16000.0
,
num_channels
=
2
,
),
StreamReader
SourceStream
(
SourceStream
(
media_type
=
"subtitle"
,
codec
=
"mov_text"
,
codec_long_name
=
"MOV text"
,
...
...
@@ -208,6 +249,98 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
output
=
[
s
.
get_src_stream_info
(
i
)
for
i
in
range
(
6
)]
assert
expected
==
output
def
test_output_info
(
self
):
s
=
StreamReader
(
self
.
get_src
())
s
.
add_audio_stream
(
-
1
)
s
.
add_audio_stream
(
-
1
,
filter_desc
=
"aresample=8000"
)
s
.
add_audio_stream
(
-
1
,
filter_desc
=
"aformat=sample_fmts=s16p"
)
s
.
add_video_stream
(
-
1
)
s
.
add_video_stream
(
-
1
,
filter_desc
=
"fps=10"
)
s
.
add_video_stream
(
-
1
,
filter_desc
=
"format=rgb24"
)
s
.
add_video_stream
(
-
1
,
filter_desc
=
"scale=w=160:h=90"
)
# Note:
# Somehow only FFmpeg 5 reports invalid video frame rate. (24576/0)
# FFmpeg 4 and 6 work fine.
# Perhaps this is a regression in FFmpeg or it could actually originate
# from other libraries.
# It consistently fails with FFmpeg installed via conda, so we change
# the value based on FFmpeg version.
ver
=
torchaudio
.
utils
.
ffmpeg_utils
.
get_versions
()[
"libavutil"
]
print
(
ver
)
major
,
minor
,
_
=
ver
if
major
==
57
:
video_frame_rate
=
-
1
else
:
video_frame_rate
=
30000
/
1001
print
(
video_frame_rate
)
expected
=
[
OutputAudioStream
(
source_index
=
4
,
filter_description
=
"anull"
,
media_type
=
"audio"
,
format
=
"fltp"
,
sample_rate
=
16000.0
,
num_channels
=
2
,
),
OutputAudioStream
(
source_index
=
4
,
filter_description
=
"aresample=8000"
,
media_type
=
"audio"
,
format
=
"fltp"
,
sample_rate
=
8000.0
,
num_channels
=
2
,
),
OutputAudioStream
(
source_index
=
4
,
filter_description
=
"aformat=sample_fmts=s16p"
,
media_type
=
"audio"
,
format
=
"s16p"
,
sample_rate
=
16000.0
,
num_channels
=
2
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"null"
,
media_type
=
"video"
,
format
=
"yuv420p"
,
width
=
480
,
height
=
270
,
frame_rate
=
30000
/
1001
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"fps=10"
,
media_type
=
"video"
,
format
=
"yuv420p"
,
width
=
480
,
height
=
270
,
frame_rate
=
10
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"format=rgb24"
,
media_type
=
"video"
,
format
=
"rgb24"
,
width
=
480
,
height
=
270
,
frame_rate
=
30000
/
1001
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"scale=w=160:h=90"
,
media_type
=
"video"
,
format
=
"yuv420p"
,
width
=
160
,
height
=
90
,
frame_rate
=
30000
/
1001
,
),
]
output
=
[
s
.
get_out_stream_info
(
i
)
for
i
in
range
(
s
.
num_out_streams
)]
assert
expected
==
output
def
test_id3tag
(
self
):
"""get_metadata method can fetch id3tag properly"""
s
=
StreamReader
(
self
.
get_src
(
"steam-train-whistle-daniel_simon.mp3"
))
...
...
@@ -418,15 +551,26 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
if
i
>=
40
:
break
def
test_seek
(
self
):
def
test_stream_requires_grad_false
(
self
):
"""Tensors produced by StreamReader are requires_grad=False"""
s
=
StreamReader
(
self
.
get_src
())
s
.
add_basic_audio_stream
(
frames_per_chunk
=
2000
)
s
.
add_basic_video_stream
(
frames_per_chunk
=
15
)
s
.
fill_buffer
()
audio
,
video
=
s
.
pop_chunks
()
assert
not
audio
.
_elem
.
requires_grad
assert
not
video
.
_elem
.
requires_grad
@
parameterized
.
expand
([
"key"
,
"any"
,
"precise"
])
def
test_seek
(
self
,
mode
):
"""Calling `seek` multiple times should not segfault"""
s
=
StreamReader
(
self
.
get_src
())
for
i
in
range
(
10
):
s
.
seek
(
i
)
s
.
seek
(
i
,
mode
)
for
_
in
range
(
0
):
s
.
seek
(
0
)
s
.
seek
(
0
,
mode
)
for
i
in
range
(
10
,
0
,
-
1
):
s
.
seek
(
i
)
s
.
seek
(
i
,
mode
)
def
test_seek_negative
(
self
):
"""Calling `seek` with negative value should raise an exception"""
...
...
@@ -434,6 +578,232 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with
self
.
assertRaises
(
RuntimeError
):
s
.
seek
(
-
1.0
)
def
test_seek_invalid_mode
(
self
):
"""Calling `seek` with an invalid model should raise an exception"""
s
=
StreamReader
(
self
.
get_src
())
with
self
.
assertRaises
(
ValueError
):
s
.
seek
(
10
,
"magic_seek"
)
@
parameterized
.
expand
(
[
# Test keyframe seek
# The source mp4 video has two key frames the first frame and 203rd frame at 8.08 second.
# If the seek time stamp is smaller than 8.08, it will seek into the first frame at 0.0 second.
(
"nasa_13013.mp4"
,
"key"
,
0.2
,
(
0
,
slice
(
None
))),
(
"nasa_13013.mp4"
,
"key"
,
8.04
,
(
0
,
slice
(
None
))),
(
"nasa_13013.mp4"
,
"key"
,
8.08
,
(
0
,
slice
(
202
,
None
))),
(
"nasa_13013.mp4"
,
"key"
,
8.12
,
(
0
,
slice
(
202
,
None
))),
# The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds.
# if we seek to a time stamp smaller than 0.4004 it will seek into the first frame at 0.0 second.
(
"nasa_13013.avi"
,
"key"
,
0.2
,
(
0
,
slice
(
None
))),
(
"nasa_13013.avi"
,
"key"
,
1.01
,
(
0
,
slice
(
24
,
None
))),
(
"nasa_13013.avi"
,
"key"
,
7.37
,
(
0
,
slice
(
216
,
None
))),
(
"nasa_13013.avi"
,
"key"
,
7.7
,
(
0
,
slice
(
216
,
None
))),
# Test precise seek
(
"nasa_13013.mp4"
,
"precise"
,
0.0
,
(
0
,
slice
(
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
0.2
,
(
0
,
slice
(
5
,
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
8.04
,
(
0
,
slice
(
201
,
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
8.08
,
(
0
,
slice
(
202
,
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
8.12
,
(
0
,
slice
(
203
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
0.0
,
(
0
,
slice
(
None
))),
(
"nasa_13013.avi"
,
"precise"
,
0.2
,
(
0
,
slice
(
1
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
8.1
,
(
0
,
slice
(
238
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
8.14
,
(
0
,
slice
(
239
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
8.17
,
(
0
,
slice
(
240
,
None
))),
# Test precise seek on video with missing PTS
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
"precise"
,
0.0
,
(
0
,
slice
(
None
))),
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
"precise"
,
0.2
,
(
0
,
slice
(
4
,
None
))),
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
"precise"
,
0.3
,
(
0
,
slice
(
7
,
None
))),
# Test any seek
# The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds.
(
"nasa_13013.avi"
,
"any"
,
0.0
,
(
0
,
slice
(
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.56
,
(
0
,
slice
(
12
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
7.77
,
(
0
,
slice
(
228
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.2002
,
(
11
,
slice
(
12
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.233567
,
(
10
,
slice
(
12
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.266933
,
(
9
,
slice
(
12
,
None
))),
]
)
def
test_seek_modes
(
self
,
src
,
mode
,
seek_time
,
ref_indices
):
"""We expect the following behaviour from the diferent kinds of seek:
- `key`: the reader will seek to the first keyframe from the timestamp given
- `precise`: the reader will seek to the first keyframe from the timestamp given
and start decoding from that position until the given timestmap (discarding all frames in between)
- `any`: the reader will seek to the colsest frame to the timestamp
given but if this is not a keyframe, the content will be the delta from other frames
To thest this behaviour we can parameterize the test with the tupple ref_indices. ref_indices[0]
is the expected index on the frames list decoded after seek and ref_indices[1] is exepected index for
the list of all frames decoded from the begining (reference frames). This test checks if
the reference frame at index ref_indices[1] is the same as ref_indices[0]. Plese note that with `any`
and `key` seek we only compare keyframes, but with `precise` seek we can compare any frame content.
"""
# Using the first video stream (which is not default video stream)
stream_index
=
0
# Decode all frames for reference
src_bin
=
self
.
get_src
(
src
)
s
=
StreamReader
(
src_bin
)
s
.
add_basic_video_stream
(
-
1
,
stream_index
=
stream_index
)
s
.
process_all_packets
()
(
ref_frames
,)
=
s
.
pop_chunks
()
s
.
seek
(
seek_time
,
mode
=
mode
)
s
.
process_all_packets
()
(
frame
,)
=
s
.
pop_chunks
()
hyp_index
,
ref_index
=
ref_indices
hyp
,
ref
=
frame
[
hyp_index
:],
ref_frames
[
ref_index
]
print
(
hyp
.
shape
,
ref
.
shape
)
self
.
assertEqual
(
hyp
,
ref
)
@
parameterized
.
expand
(
[
(
"nasa_13013.mp4"
,
[
195
,
3
,
270
,
480
]),
# RATRACE does not have valid PTS metadata.
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
[
36
,
3
,
240
,
560
]),
]
)
def
test_change_fps
(
self
,
src
,
shape
):
"""Can change the FPS of videos"""
tgt_frame_rate
=
15
s
=
StreamReader
(
self
.
get_src
(
src
))
info
=
s
.
get_src_stream_info
(
s
.
default_video_stream
)
assert
info
.
frame_rate
!=
tgt_frame_rate
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
frame_rate
=
tgt_frame_rate
)
s
.
process_all_packets
()
(
chunk
,)
=
s
.
pop_chunks
()
assert
chunk
.
shape
==
torch
.
Size
(
shape
)
def
test_invalid_chunk_option
(
self
):
"""Passing invalid `frames_per_chunk` and `buffer_chunk_size` raises error"""
s
=
StreamReader
(
self
.
get_src
())
for
fpc
,
bcs
in
((
0
,
3
),
(
3
,
0
),
(
-
2
,
3
),
(
3
,
-
2
)):
with
self
.
assertRaises
(
RuntimeError
):
s
.
add_audio_stream
(
frames_per_chunk
=
fpc
,
buffer_chunk_size
=
bcs
)
with
self
.
assertRaises
(
RuntimeError
):
s
.
add_video_stream
(
frames_per_chunk
=
fpc
,
buffer_chunk_size
=
bcs
)
def
test_unchunked_stream
(
self
):
"""`frames_per_chunk=-1` disable chunking.
When chunking is disabled, frames contained in one AVFrame become one chunk.
For video, that is always one frame, but for audio, it depends.
"""
s
=
StreamReader
(
self
.
get_src
())
s
.
add_video_stream
(
frames_per_chunk
=-
1
,
buffer_chunk_size
=
10000
)
s
.
add_audio_stream
(
frames_per_chunk
=-
1
,
buffer_chunk_size
=
10000
)
s
.
process_all_packets
()
video
,
audio
=
s
.
pop_chunks
()
assert
video
.
shape
==
torch
.
Size
([
390
,
3
,
270
,
480
])
assert
audio
.
shape
==
torch
.
Size
([
208896
,
2
])
@
parameterized
.
expand
([(
1
,),
(
3
,),
(
5
,),
(
10
,)])
def
test_frames_per_chunk
(
self
,
fpc
):
"""Changing frames_per_chunk does not change the returned content"""
src
=
self
.
get_src
()
s
=
StreamReader
(
src
)
s
.
add_video_stream
(
frames_per_chunk
=-
1
,
buffer_chunk_size
=-
1
)
s
.
add_audio_stream
(
frames_per_chunk
=-
1
,
buffer_chunk_size
=-
1
)
s
.
process_all_packets
()
ref_video
,
ref_audio
=
s
.
pop_chunks
()
if
self
.
test_type
==
"fileobj"
:
src
.
seek
(
0
)
s
=
StreamReader
(
src
)
s
.
add_video_stream
(
frames_per_chunk
=
fpc
,
buffer_chunk_size
=-
1
)
s
.
add_audio_stream
(
frames_per_chunk
=
fpc
,
buffer_chunk_size
=-
1
)
chunks
=
list
(
s
.
stream
())
video_chunks
=
torch
.
cat
([
c
[
0
]
for
c
in
chunks
if
c
[
0
]
is
not
None
])
audio_chunks
=
torch
.
cat
([
c
[
1
]
for
c
in
chunks
if
c
[
1
]
is
not
None
])
self
.
assertEqual
(
ref_video
,
video_chunks
)
self
.
assertEqual
(
ref_audio
,
audio_chunks
)
def
test_buffer_chunk_size
(
self
):
"""`buffer_chunk_size=-1` does not drop frames."""
src
=
self
.
get_src
()
s
=
StreamReader
(
src
)
s
.
add_video_stream
(
frames_per_chunk
=
30
,
buffer_chunk_size
=-
1
)
s
.
add_audio_stream
(
frames_per_chunk
=
16000
,
buffer_chunk_size
=-
1
)
s
.
process_all_packets
()
for
_
in
range
(
13
):
video
,
audio
=
s
.
pop_chunks
()
assert
video
.
shape
==
torch
.
Size
([
30
,
3
,
270
,
480
])
assert
audio
.
shape
==
torch
.
Size
([
16000
,
2
])
video
,
audio
=
s
.
pop_chunks
()
assert
video
is
None
assert
audio
.
shape
==
torch
.
Size
([
896
,
2
])
if
self
.
test_type
==
"fileobj"
:
src
.
seek
(
0
)
s
=
StreamReader
(
src
)
s
.
add_video_stream
(
frames_per_chunk
=
30
,
buffer_chunk_size
=
3
)
s
.
add_audio_stream
(
frames_per_chunk
=
16000
,
buffer_chunk_size
=
3
)
s
.
process_all_packets
()
for
_
in
range
(
2
):
video
,
audio
=
s
.
pop_chunks
()
assert
video
.
shape
==
torch
.
Size
([
30
,
3
,
270
,
480
])
assert
audio
.
shape
==
torch
.
Size
([
16000
,
2
])
video
,
audio
=
s
.
pop_chunks
()
assert
video
.
shape
==
torch
.
Size
([
30
,
3
,
270
,
480
])
assert
audio
.
shape
==
torch
.
Size
([
896
,
2
])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_video_pts(self, fpc):
    """PTS values of the first frame are reported in .pts attribute"""
    # Source video: 29.97 fps (30000/1001), 390 frames total.
    rate, num_frames = 30000 / 1001, 390
    # Each chunk of `fpc` frames should report the timestamp of its first frame.
    expected = [frame_index / rate for frame_index in range(0, num_frames, fpc)]

    reader = StreamReader(self.get_src())
    reader.add_video_stream(fpc)
    observed = [video.pts for video, in reader.stream()]
    self.assertEqual(observed, expected)
@parameterized.expand([(256,), (512,), (1024,), (4086,)])
def test_audio_pts(self, fpc):
    """PTS values of the first frame are reported in .pts attribute"""
    # Source audio: 16 kHz, 208896 samples total.
    rate, num_frames = 16000, 208896
    # Each chunk of `fpc` samples should report the timestamp of its first sample.
    expected = [sample_index / rate for sample_index in range(0, num_frames, fpc)]

    reader = StreamReader(self.get_src())
    reader.add_audio_stream(fpc, buffer_chunk_size=-1)
    observed = [audio.pts for audio, in reader.stream()]
    self.assertEqual(observed, expected)
def test_pts_unchunked_process_all(self):
    """PTS is zero when loading the entire media with unchunked buffer"""
    reader = StreamReader(self.get_src())
    # frames_per_chunk=-1 with buffer_chunk_size=-1 accumulates everything
    # into a single chunk per stream.
    reader.add_audio_stream(-1, buffer_chunk_size=-1)
    reader.add_video_stream(-1, buffer_chunk_size=-1)
    reader.process_all_packets()

    audio, video = reader.pop_chunks()
    # The single chunk spans the whole media, so its PTS is the start time.
    assert audio.pts == 0.0
    assert video.pts == 0.0
    # Full media lengths: 208896 audio samples, 390 video frames.
    assert audio.size(0) == 208896
    assert video.size(0) == 390
def test_pts_unchunked(self):
    """PTS grows proportionally to the number of frames decoded"""
    reader = StreamReader(self.get_src())
    reader.add_audio_stream(-1, buffer_chunk_size=-1)
    reader.add_video_stream(-1, buffer_chunk_size=-1)

    num_audio_frames, num_video_frames = 0, 0
    # Decode packet-by-packet; each popped chunk must start where the
    # previous one ended (audio @ 16 kHz, video @ 30000/1001 fps).
    while num_audio_frames < 208896 and num_video_frames < 390:
        reader.process_packet()
        audio, video = reader.pop_chunks()
        if audio is None and video is None:
            # The packet did not complete a chunk on either stream.
            continue
        if audio is not None:
            assert audio.pts == num_audio_frames / 16000
            num_audio_frames += audio.size(0)
        if video is not None:
            assert video.pts == num_video_frames * 1001 / 30000
            num_video_frames += video.size(0)
def
_to_fltp
(
original
):
"""Convert Tensor to float32 with value range [-1, 1]"""
...
...
@@ -493,11 +863,84 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
if
self
.
test_type
==
"fileobj"
:
src
.
seek
(
0
)
self
.
_test_wav
(
src
,
original
,
fmt
=
None
)
# convert to float32
expected
=
_to_fltp
(
original
)
if
self
.
test_type
==
"fileobj"
:
src
.
seek
(
0
)
self
.
_test_wav
(
src
,
expected
,
fmt
=
"fltp"
)
def test_audio_stream_format(self):
    """`format` argument properly changes the sample format of decoded audio"""
    num_channels = 2
    # Source is int32 WAV; `s32` is the reference tensor for it.
    src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels)
    args = {
        "num_channels": num_channels,
        "normalize": False,
        "channels_first": False,
        "num_frames": 1 << 16,
    }
    # Reference tensors for each target sample format.
    u8 = get_wav_data("uint8", **args)
    s16 = get_wav_data("int16", **args)
    s64 = s32.to(torch.int64) * (1 << 32)
    f32 = get_wav_data("float32", **args)
    f64 = get_wav_data("float64", **args)

    reader = StreamReader(src)
    # One output stream per sample format (interleaved and planar variants).
    formats = ["u8", "u8p", "s16", "s16p", "s32", "s32p", "s64", "s64p", "flt", "fltp", "dbl", "dblp"]
    for fmt in formats:
        reader.add_basic_audio_stream(frames_per_chunk=-1, format=fmt)
    reader.process_all_packets()
    chunks = reader.pop_chunks()

    # u8 conversion may round, hence the atol=1 tolerance; the rest must
    # match the reference exactly.
    self.assertEqual(chunks[0], u8, atol=1, rtol=0)
    self.assertEqual(chunks[1], u8, atol=1, rtol=0)
    for index, reference in [
        (2, s16), (3, s16),
        (4, s32), (5, s32),
        (6, s64), (7, s64),
        (8, f32), (9, f32),
        (10, f64), (11, f64),
    ]:
        self.assertEqual(chunks[index], reference)
@nested_params([4000, 16000])
def test_basic_audio_stream_sample_rate(self, sr):
    """`sample_rate` argument changes the sample_rate of decoded audio"""
    src_num_channels, src_sr = 2, 8000
    # Write a 1-second reference file at the source rate.
    waveform = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False)
    path = self.get_temp_path("ref.wav")
    save_wav(path, waveform, src_sr, channels_first=False)

    reader = StreamReader(path)
    reader.add_basic_audio_stream(frames_per_chunk=-1, format="flt", sample_rate=sr)
    # Source keeps its rate; the output stream reports the requested rate.
    self.assertEqual(reader.get_src_stream_info(0).sample_rate, src_sr)
    self.assertEqual(reader.get_out_stream_info(0).sample_rate, sr)

    reader.process_all_packets()
    (chunks,) = reader.pop_chunks()
    # 1 second of audio resampled to `sr` -> exactly `sr` frames.
    self.assertEqual(chunks.shape, [sr, src_num_channels])
@nested_params([1, 2, 3, 8, 16])
def test_basic_audio_stream_num_channels(self, num_channels):
    """`num_channels` argument changes the number of channels of decoded audio"""
    # NOTE: the docstring previously said "`sample_rate` argument" — a
    # copy-paste slip from the sibling test; this test exercises `num_channels`.
    src_num_channels, sr = 2, 8000
    data = get_sinusoid(sample_rate=sr, n_channels=src_num_channels, channels_first=False)
    path = self.get_temp_path("ref.wav")
    save_wav(path, data, sr, channels_first=False)

    s = StreamReader(path)
    s.add_basic_audio_stream(frames_per_chunk=-1, format="flt", num_channels=num_channels)
    # Source keeps its channel count; the output stream reports the requested one.
    self.assertEqual(s.get_src_stream_info(0).num_channels, src_num_channels)
    self.assertEqual(s.get_out_stream_info(0).num_channels, num_channels)

    s.process_all_packets()
    (chunks,) = s.pop_chunks()
    self.assertEqual(chunks.shape, [sr, num_channels])
@
nested_params
(
[
"int16"
,
"uint8"
,
"int32"
],
# "float", "double", "int64"]
...
...
@@ -630,23 +1073,192 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
rgb
=
torch
.
empty
(
1
,
3
,
256
,
256
,
dtype
=
torch
.
uint8
)
rgb
[
0
,
0
]
=
torch
.
arange
(
256
,
dtype
=
torch
.
uint8
).
reshape
([
1
,
-
1
])
rgb
[
0
,
1
]
=
torch
.
arange
(
256
,
dtype
=
torch
.
uint8
).
reshape
([
-
1
,
1
])
alpha
=
torch
.
full
((
1
,
1
,
256
,
256
),
255
,
dtype
=
torch
.
uint8
)
for
i
in
range
(
256
):
rgb
[
0
,
2
]
=
i
path
=
self
.
get_temp_path
(
f
"ref_
{
i
}
.png"
)
save_image
(
path
,
rgb
[
0
],
mode
=
"RGB"
)
rgb16
=
((
rgb
.
to
(
torch
.
int32
)
-
128
)
<<
8
).
to
(
torch
.
int16
)
yuv
=
rgb_to_yuv_ccir
(
rgb
)
yuv16
=
yuv
.
to
(
torch
.
int16
)
*
4
bgr
=
rgb
[:,
[
2
,
1
,
0
],
:,
:]
gray
=
rgb_to_gray
(
rgb
)
argb
=
torch
.
cat
([
alpha
,
rgb
],
dim
=
1
)
rgba
=
torch
.
cat
([
rgb
,
alpha
],
dim
=
1
)
abgr
=
torch
.
cat
([
alpha
,
bgr
],
dim
=
1
)
bgra
=
torch
.
cat
([
bgr
,
alpha
],
dim
=
1
)
s
=
StreamReader
(
path
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv444p"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv420p"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"nv12"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgb24"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"bgr24"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"gray8"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgb48le"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"argb"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgba"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"abgr"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"bgra"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv420p10le"
)
s
.
process_all_packets
()
output_yuv
,
output_rgb
,
output_bgr
,
output_gray
=
s
.
pop_chunks
()
self
.
assertEqual
(
yuv
,
output_yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
rgb
,
output_rgb
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
bgr
,
output_bgr
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
gray
,
output_gray
,
atol
=
1
,
rtol
=
0
)
chunks
=
s
.
pop_chunks
()
self
.
assertEqual
(
chunks
[
0
],
yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
1
],
yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
2
],
yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
3
],
rgb
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
4
],
bgr
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
5
],
gray
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
6
],
rgb16
,
atol
=
256
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
7
],
argb
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
8
],
rgba
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
9
],
abgr
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
10
],
bgra
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
11
],
yuv16
,
atol
=
4
,
rtol
=
0
)
@skipIfNoHWAccel("h264_cuvid")
class CuvidHWAccelInterfaceTest(TorchaudioTestCase):
    def test_dup_hw_acel(self):
        """Specifying the same source stream with and without HW accel should fail (instead of segfault later)"""
        src = get_asset_path("nasa_13013.mp4")

        # Case 1: first without HW accel, then with it -> must raise.
        reader = StreamReader(src)
        reader.add_video_stream(-1, decoder="h264_cuvid")
        with self.assertRaises(RuntimeError):
            reader.add_video_stream(-1, decoder="h264_cuvid", hw_accel="cuda")

        # Case 2: first with HW accel, then without -> must raise as well.
        reader = StreamReader(src)
        reader.add_video_stream(-1, decoder="h264_cuvid", hw_accel="cuda")
        with self.assertRaises(RuntimeError):
            reader.add_video_stream(-1, decoder="h264_cuvid")
@_media_source
class CudaDecoderTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
    def _test_decode(
        self,
        decoder: str,
        src_path: str,
        height: int,
        width: int,
        ref_num_frames: int,
        hw_accel=None,
        decoder_option=None,
        dtype: torch.dtype = torch.uint8,
    ):
        """Decode `src_path` with `decoder` and validate device/dtype/shape of every chunk."""
        src = self.get_src(get_asset_path(src_path))
        reader = StreamReader(src)
        reader.add_video_stream(10, decoder=decoder, decoder_option=decoder_option, hw_accel=hw_accel)

        expected_device = torch.device(hw_accel or "cpu")
        expected_shape = torch.Size([10, 3, height, width])
        decoded = 0
        for (chunk,) in reader.stream():
            self.assertEqual(chunk.device, expected_device)
            self.assertEqual(chunk.dtype, dtype)
            self.assertEqual(chunk.shape, expected_shape)
            decoded += chunk.size(0)
        assert decoded == ref_num_frames

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid(self):
        """GPU decoder works for H264"""
        self._test_decode("h264_cuvid", "nasa_13013.mp4", 270, 480, 390)

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid_hw_accel(self):
        """GPU decoder works for H264 with HW acceleration, and put the frames on CUDA tensor"""
        self._test_decode("h264_cuvid", "nasa_13013.mp4", 270, 480, 390, hw_accel="cuda:0")

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid_hw_accel_resize(self):
        """GPU decoder works for H264 with HW acceleration and resize option"""
        w, h = 240, 136
        self._test_decode(
            "h264_cuvid",
            "nasa_13013.mp4",
            h,
            w,
            390,
            hw_accel="cuda:0",
            decoder_option={"resize": f"{w}x{h}"},
        )

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid_hw_accel_crop(self):
        """GPU decoder works for H264 with HW acceleration and crop option"""
        top, bottom, left, right = 3, 5, 7, 9
        # 270 - (3+5) = 262 rows, 480 - (7+9) = 464 columns after cropping.
        self._test_decode(
            "h264_cuvid",
            "nasa_13013.mp4",
            262,
            464,
            390,
            hw_accel="cuda:0",
            decoder_option={"crop": f"{top}x{bottom}x{left}x{right}"},
        )

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid(self):
        """GPU decoder works for H265/HEVC"""
        self._test_decode("hevc_cuvid", "testsrc.hevc", 144, 256, 300)

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid_hw_accel(self):
        """GPU decoder works for H265/HEVC with HW acceleration, and put the frames on CUDA tensor"""
        self._test_decode("hevc_cuvid", "testsrc.hevc", 144, 256, 300, hw_accel="cuda:0", dtype=torch.int16)

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid_hw_accel_resize(self):
        """GPU decoder works for H265/HEVC with HW acceleration and resize option"""
        w, h = 128, 64
        self._test_decode(
            "hevc_cuvid",
            "testsrc.hevc",
            h,
            w,
            300,
            hw_accel="cuda:0",
            dtype=torch.int16,
            decoder_option={"resize": f"{w}x{h}"},
        )

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid_hw_accel_crop(self):
        """GPU decoder works for H265/HEVC with HW acceleration and crop option"""
        top, bottom, left, right = 3, 5, 7, 9
        # 144 - (3+5) = 136 rows, 256 - (7+9) = 240 columns after cropping.
        self._test_decode(
            "hevc_cuvid",
            "testsrc.hevc",
            136,
            240,
            300,
            hw_accel="cuda:0",
            dtype=torch.int16,
            decoder_option={"crop": f"{top}x{bottom}x{left}x{right}"},
        )
@skipIfNoHWAccel("h264_cuvid")
# Disabled in CI: https://github.com/pytorch/audio/issues/3376
@disabledInCI
class FilterGraphWithCudaAccel(TorchaudioTestCase):
    # NOTE(review): "sclae" in the first test name is a typo for "scale"; the
    # name is kept because test names are part of the suite's public surface.
    def test_sclae_cuda_change_size(self):
        """scale_cuda filter can be used when HW accel is on"""
        src = get_asset_path("nasa_13013.mp4")
        reader = StreamReader(src)
        reader.add_video_stream(
            10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=iw/2:ih/2"
        )
        total = 0
        for (chunk,) in reader.stream():
            self.assertEqual(chunk.device, torch.device("cuda:0"))
            self.assertEqual(chunk.dtype, torch.uint8)
            # Half of the source 270x480.
            self.assertEqual(chunk.shape, torch.Size([10, 3, 135, 240]))
            total += chunk.size(0)
        assert total == 390

    def test_scale_cuda_format(self):
        """yuv444p format conversion should work"""
        src = get_asset_path("nasa_13013.mp4")
        reader = StreamReader(src)
        reader.add_video_stream(
            10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=format=yuv444p"
        )
        total = 0
        for (chunk,) in reader.stream():
            self.assertEqual(chunk.device, torch.device("cuda:0"))
            self.assertEqual(chunk.dtype, torch.uint8)
            # Format conversion only; spatial size is unchanged.
            self.assertEqual(chunk.shape, torch.Size([10, 3, 270, 480]))
            total += chunk.size(0)
        assert total == 390
test/torchaudio_unittest/io/stream_writer_test.py
View file @
ffeba11a
import
io
import
math
import
torch
import
torchaudio
from
parameterized
import
parameterized
,
parameterized_class
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
get_sinusoid
,
is_ffmpeg_available
,
nested_params
,
rgb_to_yuv_ccir
,
...
...
@@ -13,8 +17,10 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase
,
)
from
.common
import
lt42
if
is_ffmpeg_available
():
from
torchaudio.io
import
StreamReader
,
StreamWriter
from
torchaudio.io
import
CodecConfig
,
StreamReader
,
StreamWriter
def
get_audio_chunk
(
fmt
,
sample_rate
,
num_channels
):
...
...
@@ -87,9 +93,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
def get_dst(self, path):
    """Resolve `path` inside the temp dir, then delegate to the mixin's get_dst."""
    resolved = self.get_temp_path(path)
    return super().get_dst(resolved)
def get_buf(self, path):
    """Return the raw bytes previously written to `path` in the temp dir."""
    with open(self.get_temp_path(path), "rb") as stream:
        return stream.read()
def test_unopened_error(self):
    """If dst is not opened when attempting to write data, runtime error should be raised"""
    path = self.get_dst("test.mp4")
    writer = StreamWriter(path, format="mp4")
    writer.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
    writer.add_audio_stream(sample_rate=16000, num_channels=2)
    writer.add_video_stream(frame_rate=30, width=16, height=16)

    # Writing without calling writer.open() must raise for both stream types.
    dummy = torch.zeros((3, 2))
    with self.assertRaises(RuntimeError):
        writer.write_audio_chunk(0, dummy)

    dummy = torch.zeros((3, 3, 16, 16))
    with self.assertRaises(RuntimeError):
        writer.write_video_chunk(1, dummy)
@
skipIfNoModule
(
"tinytag"
)
def
test_metadata_overwrite
(
self
):
...
...
@@ -135,21 +153,26 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
@
parameterized
.
expand
(
[
(
"mp3"
,
8000
,
1
,
"s32p"
,
None
),
(
"mp3"
,
16000
,
2
,
"fltp"
,
None
),
(
"mp3"
,
44100
,
1
,
"s16p"
,
{
"abr"
:
"true"
}),
(
"flac"
,
8000
,
1
,
"s16"
,
None
),
(
"flac"
,
16000
,
2
,
"s32"
,
None
),
(
"opus"
,
48000
,
2
,
None
,
{
"strict"
:
"experimental"
}),
(
"adts"
,
8000
,
1
,
"fltp"
,
None
),
# AAC format
(
"mp3"
,
8000
,
1
,
None
,
"s32p"
,
None
),
(
"mp3"
,
16000
,
2
,
None
,
"fltp"
,
None
),
(
"mp3"
,
44100
,
1
,
None
,
"s16p"
,
{
"abr"
:
"true"
}),
(
"flac"
,
8000
,
1
,
None
,
"s16"
,
None
),
(
"flac"
,
16000
,
2
,
None
,
"s32"
,
None
),
(
"opus"
,
48000
,
2
,
"opus"
,
None
,
None
),
(
"ogg"
,
48000
,
2
,
"vorbis"
,
None
,
None
),
(
"adts"
,
8000
,
1
,
None
,
"fltp"
,
None
),
# AAC format
]
)
def
test_valid_audio_muxer_and_codecs
(
self
,
ext
,
sample_rate
,
num_channels
,
encoder_format
,
encoder_option
):
def
test_valid_audio_muxer_and_codecs
(
self
,
ext
,
sample_rate
,
num_channels
,
encoder
,
encoder_format
,
encoder_option
):
"""Tensor of various dtypes can be saved as given format."""
path
=
self
.
get_dst
(
f
"test.
{
ext
}
"
)
s
=
StreamWriter
(
path
,
format
=
ext
)
s
.
set_metadata
(
metadata
=
{
"artist"
:
"torchaudio"
,
"title"
:
self
.
id
()})
s
.
add_audio_stream
(
sample_rate
,
num_channels
,
encoder_option
=
encoder_option
,
encoder_format
=
encoder_format
)
s
.
add_audio_stream
(
sample_rate
,
num_channels
,
encoder
=
encoder
,
encoder_option
=
encoder_option
,
encoder_format
=
encoder_format
)
chunk
=
get_audio_chunk
(
"flt"
,
sample_rate
,
num_channels
)
with
s
.
open
():
...
...
@@ -202,6 +225,19 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s
.
write_audio_chunk
(
0
,
audio
)
s
.
write_video_chunk
(
1
,
video
)
@
skipIfNoFFmpeg
class
StreamWriterCorrectnessTest
(
TempDirMixin
,
TorchaudioTestCase
):
@classmethod
def setUpClass(cls):
    """Raise the FFmpeg log level for the whole suite (32 = AV_LOG_INFO, presumably — TODO confirm)."""
    super().setUpClass()
    torchaudio.utils.ffmpeg_utils.set_log_level(32)
@classmethod
def tearDownClass(cls):
    """Restore the quieter FFmpeg log level (8 = AV_LOG_FATAL, presumably — TODO confirm) after the suite."""
    torchaudio.utils.ffmpeg_utils.set_log_level(8)
    super().tearDownClass()
@
nested_params
(
[
(
"gray8"
,
"gray8"
),
...
...
@@ -227,16 +263,16 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
chunk
=
torch
.
randint
(
low
=
0
,
high
=
255
,
size
=
src_size
,
dtype
=
torch
.
uint8
)
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
s
=
StreamWriter
(
dst
,
format
=
"rawvideo"
)
s
.
add_video_stream
(
frame_rate
,
width
,
height
,
format
=
src_fmt
,
encoder_format
=
encoder_fmt
)
with
s
.
open
():
s
.
write_video_chunk
(
0
,
chunk
)
# Fetch the written data
if
self
.
test_
fileobj
:
dst
.
flush
()
buf
=
self
.
get_buf
(
filename
)
with
open
(
dst
,
"rb"
)
as
fileobj
:
buf
=
fileobj
.
read
()
result
=
torch
.
frombuffer
(
buf
,
dtype
=
torch
.
uint8
)
if
encoder_fmt
.
endswith
(
"p"
):
result
=
result
.
reshape
(
src_size
)
...
...
@@ -261,14 +297,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
h
,
w
=
resolution
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
.
add_video_stream
(
frame_rate
=
framerate
,
height
=
h
,
width
=
w
,
format
=
format
)
chunk
=
torch
.
stack
([
torch
.
full
((
3
,
h
,
w
),
i
,
dtype
=
torch
.
uint8
)
for
i
in
torch
.
linspace
(
0
,
255
,
256
)])
with
s
.
open
():
s
.
write_video_chunk
(
0
,
chunk
)
if
self
.
test_fileobj
:
dst
.
flush
()
# Load data
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
...
...
@@ -293,30 +327,54 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
pass
@
nested_params
(
[
"wav"
,
"mp3"
,
"flac"
],
[
"wav"
,
"flac"
],
[
8000
,
16000
,
44100
],
[
1
,
2
],
)
def
test_audio_num_frames
(
self
,
ext
,
sample_rate
,
num_channels
):
""""""
def
test_audio_num_frames
_lossless
(
self
,
ext
,
sample_rate
,
num_channels
):
"""
Lossless format preserves the data
"""
filename
=
f
"test.
{
ext
}
"
data
=
get_sinusoid
(
sample_rate
=
sample_rate
,
n_channels
=
num_channels
,
dtype
=
"int16"
,
channels_first
=
False
)
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
)
s
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
format
=
"s16"
)
with
s
.
open
():
s
.
write_audio_chunk
(
0
,
data
)
freq
=
300
duration
=
60
theta
=
torch
.
linspace
(
0
,
freq
*
2
*
3.14
*
duration
,
sample_rate
*
duration
)
if
num_channels
==
1
:
chunk
=
torch
.
sin
(
theta
).
unsqueeze
(
-
1
)
else
:
chunk
=
torch
.
stack
([
torch
.
sin
(
theta
),
torch
.
cos
(
theta
)],
dim
=-
1
)
# Load data
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
s
.
add_audio_stream
(
-
1
)
s
.
process_all_packets
()
(
saved
,)
=
s
.
pop_chunks
()
self
.
assertEqual
(
saved
,
data
)
@
parameterized
.
expand
(
[
(
"mp3"
,
1
,
8000
),
(
"mp3"
,
1
,
16000
),
(
"mp3"
,
1
,
44100
),
(
"mp3"
,
2
,
8000
),
(
"mp3"
,
2
,
16000
),
(
"mp3"
,
2
,
44100
),
(
"opus"
,
1
,
48000
),
]
)
def
test_audio_num_frames_lossy
(
self
,
ext
,
num_channels
,
sample_rate
):
"""Saving audio preserves the number of channels and frames"""
filename
=
f
"test.
{
ext
}
"
data
=
get_sinusoid
(
sample_rate
=
sample_rate
,
n_channels
=
num_channels
,
channels_first
=
False
)
# Write data
dst
=
self
.
get_temp_path
(
filename
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
)
with
s
.
open
():
s
.
write_audio_chunk
(
0
,
chunk
)
if
self
.
test_fileobj
:
dst
.
flush
()
s
.
write_audio_chunk
(
0
,
data
)
# Load data
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
...
...
@@ -324,9 +382,28 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s
.
process_all_packets
()
(
saved
,)
=
s
.
pop_chunks
()
assert
saved
.
shape
==
chunk
.
shape
if
format
in
[
"wav"
,
"flac"
]:
self
.
assertEqual
(
saved
,
chunk
)
# On 4.1 OPUS produces 48312 samples (extra 312)
# this has been fixed on 4.2+
# TODO: issue warning if on 4.1?
if
ext
==
"opus"
and
lt42
():
return
self
.
assertEqual
(
saved
.
shape
,
data
.
shape
)
def test_g722_sample_rate(self):
    """Encoding G.722 properly converts sample rate to 16k"""
    filename = "test.g722"
    source_rate = 41000
    waveform = get_sinusoid(sample_rate=source_rate, n_channels=1, channels_first=False)

    # Encode at an arbitrary rate; G.722 only supports 16 kHz.
    dst = self.get_temp_path(filename)
    writer = StreamWriter(dst, format="g722")
    writer.add_audio_stream(sample_rate=source_rate, num_channels=1)
    with writer.open():
        writer.write_audio_chunk(0, waveform)

    # The written file must report 16 kHz.
    reader = StreamReader(src=self.get_temp_path(filename))
    self.assertEqual(reader.get_src_stream_info(0).sample_rate, 16000)
def
test_preserve_fps
(
self
):
"""Decimal point frame rate is properly saved
...
...
@@ -339,16 +416,346 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
width
,
height
=
96
,
128
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
writer
.
add_video_stream
(
frame_rate
=
frame_rate
,
width
=
width
,
height
=
height
)
video
=
torch
.
randint
(
256
,
(
90
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
with
writer
.
open
():
writer
.
write_video_chunk
(
0
,
video
)
if
self
.
test_fileobj
:
dst
.
flush
()
# Load data
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
assert
reader
.
get_src_stream_info
(
0
).
frame_rate
==
frame_rate
def test_video_pts_increment(self):
    """PTS values increment by the inverse of frame rate"""
    ext = "mp4"
    num_frames = 256
    filename = f"test.{ext}"
    frame_rate = 5000 / 167  # non-integer rate to exercise rational timestamps
    width, height = 96, 128

    # Encode random frames.
    dst = self.get_temp_path(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
    video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
    with writer.open():
        writer.write_video_chunk(0, video)

    # Decode one frame at a time and compare each timestamp to i / frame_rate.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    reader.add_video_stream(1)
    timestamps = [chunk.pts for (chunk,) in reader.stream()]
    assert len(timestamps) == num_frames
    for i, observed in enumerate(timestamps):
        assert abs(observed - i / frame_rate) < 1e-10
def test_audio_pts_increment(self):
    """PTS values increment by the inverse of sample rate"""
    ext = "wav"
    filename = f"test.{ext}"
    sample_rate = 8000
    num_channels = 2

    # Write data
    dst = self.get_temp_path(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
    audio = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
    num_frames = audio.size(0)
    with writer.open():
        writer.write_audio_chunk(0, audio)

    # Read back in quarter-second chunks and check each chunk's PTS equals
    # the number of samples consumed so far divided by the sample rate.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    frames_per_chunk = sample_rate // 4
    reader.add_audio_stream(frames_per_chunk, -1)

    chunks = [chunk for (chunk,) in reader.stream()]
    expected = num_frames // frames_per_chunk
    assert len(chunks) == expected, f"Expected {expected} elements. Found {len(chunks)}"
    num_samples = 0
    for chunk in chunks:
        expected = num_samples / sample_rate
        num_samples += chunk.size(0)
        # Fixed: removed a leftover debug `print(chunk.pts, expected)` that
        # spammed stdout on every chunk of every parameterized run.
        assert abs(chunk.pts - expected) < 1e-10
@parameterized.expand(
    [
        (10, 100),
        (15, 150),
        (24, 240),
        (25, 200),
        (30, 300),
        (50, 500),
        (60, 600),
        # PTS value conversion involves float <-> int conversion, which can
        # introduce rounding error.
        # This test is a spot-check for popular 29.97 Hz
        (30000 / 1001, 10010),
    ]
)
def test_video_pts_overwrite(self, frame_rate, num_frames):
    """Can overwrite PTS"""
    ext = "mp4"
    filename = f"test.{ext}"
    width, height = 8, 8

    # Write one frame at a time, supplying an explicit PTS for each.
    dst = self.get_temp_path(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
    frame = torch.zeros((1, 3, height, width), dtype=torch.uint8)
    reference_pts = []
    with writer.open():
        for i in range(num_frames):
            timestamp = i / frame_rate
            reference_pts.append(timestamp)
            writer.write_video_chunk(0, frame, timestamp)

    # Read back and compare the recovered PTS values.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    reader.add_video_stream(1)
    observed = [chunk.pts for (chunk,) in reader.stream()]
    assert len(observed) == len(reference_pts)
    for val, ref in zip(observed, reference_pts):
        # torch provides isclose, but we don't know if converting floats to tensor
        # could introduce a descrepancy, so we compare floats and use math.isclose
        # for that.
        assert math.isclose(val, ref)
def test_codec_config(self):
    """Can successfully set configuration and write audio."""
    ext = "mp3"
    filename = f"test.{ext}"
    sample_rate = 44100
    num_channels = 2

    # Attach an explicit codec configuration and make sure writing succeeds.
    dst = self.get_temp_path(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    config = CodecConfig(bit_rate=198_000, compression_level=3)
    writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, codec_config=config)

    silence = torch.zeros((8000, 2))
    with writer.open():
        writer.write_audio_chunk(0, silence)
def test_codec_config_bit_rate_output(self):
    """Increasing the specified bit rate yields a larger encoded output."""
    ext = "mp3"
    sample_rate = 44100
    num_channels = 2
    audio = torch.rand((8000, num_channels))

    def write_audio(buffer, bit_rate):
        # Encode `audio` into `buffer` with the requested bit rate.
        writer = torchaudio.io.StreamWriter(dst=buffer, format=ext)
        writer.add_audio_stream(
            sample_rate=sample_rate,
            num_channels=num_channels,
            codec_config=CodecConfig(bit_rate=bit_rate),
        )
        with writer.open():
            writer.write_audio_chunk(0, audio)

    # Encode the same signal twice at different bit rates; the higher bit
    # rate must produce the larger file.
    low = io.BytesIO()
    write_audio(low, 198_000)
    low_size = low.tell()

    high = io.BytesIO()
    write_audio(high, 320_000)
    high_size = high.tell()

    self.assertGreater(high_size, low_size)
def test_filter_graph_audio(self):
    """Can apply additional effect with filter graph"""
    sample_rate = 8000
    num_channels = 2
    ext = "wav"
    filename = f"test.{ext}"
    original = get_audio_chunk("s16", num_channels=num_channels, sample_rate=sample_rate)

    # Encode with an `areverse` filter so the file holds the reversed signal.
    dst = self.get_temp_path(filename)
    writer = StreamWriter(dst, format=ext)
    writer.add_audio_stream(sample_rate=8000, num_channels=num_channels, filter_desc="areverse", format="s16")
    with writer.open():
        writer.write_audio_chunk(0, original)

    # Decode and verify the content equals the time-reversed original.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    reader.add_audio_stream(-1)
    reader.process_all_packets()
    (output,) = reader.pop_chunks()
    self.assertEqual(output, original.flip(0))
def test_filter_graph_video(self):
    """Can apply additional effect with filter graph"""
    src_rate = 30
    num_frames, width, height = 400, 160, 90
    filter_desc = "framestep=2"  # keep every 2nd frame
    enc_rate = 15
    ext = "mp4"
    filename = f"test.{ext}"
    original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)

    # Encode with a frame-dropping filter and a matching encoder frame rate.
    dst = self.get_temp_path(filename)
    writer = StreamWriter(dst, format=ext)
    writer.add_video_stream(
        frame_rate=src_rate,
        format="rgb24",
        height=height,
        width=width,
        filter_desc=filter_desc,
        encoder_format="yuv420p",
        encoder_frame_rate=enc_rate,
    )
    with writer.open():
        writer.write_video_chunk(0, original)

    # Decode and verify exactly half the frames survived.
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    reader.add_video_stream(-1)
    reader.process_all_packets()
    (output,) = reader.pop_chunks()
    self.assertEqual(output.shape, [num_frames // 2, 3, height, width])
@parameterized.expand(
    [
        ("wav", "pcm_s16le", 8000, 16000, 1, 2),
        ("wav", "pcm_s16le", 8000, 16000, 2, 1),
        ("wav", "pcm_s16le", 8000, 16000, 2, 4),
        ("wav", "pcm_s16le", 16000, 8000, 1, 2),
        ("wav", "pcm_s16le", 16000, 8000, 2, 1),
        ("wav", "pcm_s16le", 16000, 8000, 2, 4),
        ("wav", "pcm_f32le", 8000, 16000, 1, 2),
        ("wav", "pcm_f32le", 8000, 16000, 2, 1),
        ("wav", "pcm_f32le", 8000, 16000, 2, 4),
        ("wav", "pcm_f32le", 16000, 8000, 1, 2),
        ("wav", "pcm_f32le", 16000, 8000, 2, 1),
        ("wav", "pcm_f32le", 16000, 8000, 2, 4),
        ("ogg", "opus", 8000, 48000, 1, 2),
        ("ogg", "opus", 8000, 48000, 2, 1),
        ("ogg", "flac", 8000, 41000, 1, 2),
        ("ogg", "flac", 8000, 41000, 2, 1),
        ("ogg", "vorbis", 16000, 8000, 1, 2),
        ("ogg", "vorbis", 16000, 8000, 4, 2),
    ]
)
def test_change_audio_encoder_spec(self, ext, encoder, src_sr, enc_sr, src_num_channels, enc_num_channels):
    """Can change sample rate and channels on-the-fly"""
    filename = f"test.{ext}"
    original = get_sinusoid(
        sample_rate=src_sr, n_channels=src_num_channels, channels_first=False, duration=0.1
    )

    # Encode while asking the encoder to use a different rate/channel layout
    # than the incoming tensor.
    dst = self.get_temp_path(filename)
    writer = StreamWriter(dst, format=ext)
    writer.add_audio_stream(
        sample_rate=src_sr,
        format="flt",
        num_channels=src_num_channels,
        encoder=encoder,
        encoder_sample_rate=enc_sr,
        encoder_num_channels=enc_num_channels,
    )
    with writer.open():
        writer.write_audio_chunk(0, original)

    # check
    reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
    info = reader.get_src_stream_info(0)
    self.assertEqual(info.sample_rate, enc_sr)
    self.assertEqual(info.num_channels, enc_num_channels)
@
parameterized
.
expand
(
[
# opus only supports 48kHz
(
"ogg"
,
"opus"
,
8000
,
48000
,
1
,
1
),
(
"ogg"
,
"opus"
,
16000
,
48000
,
2
,
2
),
# vorbis only supports 2 channels
(
"ogg"
,
"vorbis"
,
16000
,
16000
,
1
,
2
),
(
"ogg"
,
"vorbis"
,
16000
,
16000
,
2
,
2
),
(
"ogg"
,
"vorbis"
,
16000
,
16000
,
4
,
2
),
]
)
def
test_change_encoder_spec_default
(
self
,
ext
,
encoder
,
src_sr
,
expected_sr
,
src_num_channels
,
expected_num_channels
):
"""If input rate/channels are not supported, encoder picks supported one automatically."""
filename
=
f
"test.
{
ext
}
"
original
=
get_sinusoid
(
sample_rate
=
src_sr
,
n_channels
=
src_num_channels
,
channels_first
=
False
,
duration
=
0.1
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_audio_stream
(
sample_rate
=
src_sr
,
format
=
"flt"
,
num_channels
=
src_num_channels
,
encoder
=
encoder
,
)
with
w
.
open
():
w
.
write_audio_chunk
(
0
,
original
)
# check
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
i
=
reader
.
get_src_stream_info
(
0
)
self
.
assertEqual
(
i
.
sample_rate
,
expected_sr
)
self
.
assertEqual
(
i
.
num_channels
,
expected_num_channels
)
@
parameterized
.
expand
(
[
(
"mp4"
,
None
,
10
,
30
,
(
100
,
160
),
(
200
,
320
)),
(
"mp4"
,
None
,
10
,
30
,
(
100
,
160
),
(
50
,
80
)),
(
"mp4"
,
None
,
30
,
10
,
(
100
,
160
),
(
200
,
320
)),
(
"mp4"
,
None
,
30
,
10
,
(
100
,
160
),
(
50
,
80
)),
]
)
def
test_change_video_encoder_spec
(
self
,
ext
,
encoder
,
src_rate
,
enc_rate
,
src_size
,
enc_size
):
"""Can change the frame rate and image size on-the-fly"""
width
,
height
=
src_size
enc_width
,
enc_height
=
enc_size
ext
=
"mp4"
filename
=
f
"test.
{
ext
}
"
num_frames
=
256
original
=
torch
.
zeros
((
num_frames
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_video_stream
(
frame_rate
=
src_rate
,
format
=
"rgb24"
,
height
=
height
,
width
=
width
,
encoder_format
=
"yuv420p"
,
encoder_frame_rate
=
enc_rate
,
encoder_width
=
enc_width
,
encoder_height
=
enc_height
,
)
with
w
.
open
():
w
.
write_video_chunk
(
0
,
original
)
# check
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
i
=
reader
.
get_src_stream_info
(
0
)
self
.
assertEqual
(
i
.
frame_rate
,
enc_rate
)
self
.
assertEqual
(
i
.
width
,
enc_width
)
self
.
assertEqual
(
i
.
height
,
enc_height
)
test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
View file @
ffeba11a
...
...
@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
expected_tokens
=
[
"|"
,
"f"
,
"|"
,
"o"
,
"a"
]
self
.
assertEqual
(
tokens
,
expected_tokens
)
    def test_lm_lifecycle(self):
        """Passing lm without assigning it to a variable won't cause runtime error

        https://github.com/pytorch/audio/issues/3218
        """
        from torchaudio.models.decoder import ctc_decoder

        from .ctc_decoder_utils import CustomZeroLM

        # The LM is passed as a temporary (never bound to a local name), so the
        # decoder must hold its own reference to keep the object alive.
        decoder = ctc_decoder(
            lexicon=get_asset_path("decoder/lexicon.txt"),
            tokens=get_asset_path("decoder/tokens.txt"),
            lm=CustomZeroLM(),
        )
        # Decoding must succeed even though the caller no longer references the LM.
        decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32))
test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py
0 → 100644
View file @
ffeba11a
import
torch
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
skipIfNoCuCtcDecoder
,
skipIfNoCuda
,
TempDirMixin
,
TorchaudioTestCase
,
)
NUM_TOKENS
=
7
@skipIfNoCuda
@skipIfNoCuCtcDecoder
class CUCTCDecoderTest(TempDirMixin, TorchaudioTestCase):
    """Tests for the CUDA-based CTC beam-search decoder."""

    def _get_decoder(self, tokens=None, **kwargs):
        """Build a ``cuda_ctc_decoder``; defaults to the bundled token file."""
        from torchaudio.models.decoder import cuda_ctc_decoder

        if tokens is None:
            tokens = get_asset_path("decoder/tokens.txt")
        return cuda_ctc_decoder(tokens=tokens, beam_size=5, **kwargs)

    def _get_emissions(self):
        """Random, normalized log-probability emissions on GPU."""
        num_seqs, num_steps, num_labels = 4, 15, NUM_TOKENS
        raw = torch.rand(num_seqs, num_steps, num_labels).cuda()
        return torch.nn.functional.log_softmax(raw, -1)

    def test_construct_basic_decoder_path(self):
        """Decoder can be constructed from a token-file path."""
        self._get_decoder(tokens=get_asset_path("decoder/tokens.txt"))

    def test_construct_basic_decoder_tokens(self):
        """Decoder can be constructed from an in-memory token list."""
        self._get_decoder(tokens=["-", "|", "f", "o", "b", "a", "r"])

    def test_shape(self):
        """Decoder returns one result per input sequence in the batch."""
        log_probs = self._get_emissions()
        encoder_out_lens = torch.tensor([15, 14, 13, 12], dtype=torch.int32).cuda()
        results = self._get_decoder()(log_probs, encoder_out_lens)
        self.assertEqual(len(results), log_probs.shape[0])
test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
View file @
ffeba11a
...
...
@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
self
.
assertEqual
(
res
,
scripted_res
)
state
=
res
[
1
]
hypo
=
res
[
0
]
[
0
]
hypo
=
res
[
0
]
scripted_state
=
scripted_res
[
1
]
scripted_hypo
=
scripted_res
[
0
]
[
0
]
scripted_hypo
=
scripted_res
[
0
]
test/torchaudio_unittest/models/squim/__init__.py
0 → 100644
View file @
ffeba11a
test/torchaudio_unittest/models/squim/squim_test.py
0 → 100644
View file @
ffeba11a
import
torch
from
parameterized
import
parameterized
from
torchaudio.models
import
squim_objective_base
,
squim_subjective_base
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
torch_script
,
TorchaudioTestCase
class TestSquimObjective(TorchaudioTestCase):
    """Tests for the SQUIM objective-metric model (``squim_objective_base``)."""

    def _smoke_test_objective(self, model, device, dtype):
        """Run one forward pass on random input; passes if nothing raises."""
        model = model.to(device=device, dtype=dtype).eval()
        waveforms = torch.randn(3, 16000, device=device, dtype=dtype)
        model(waveforms)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """Model runs on CPU for both float dtypes."""
        self._smoke_test_objective(squim_objective_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """Model runs on CUDA for both float dtypes."""
        self._smoke_test_objective(squim_objective_base(), torch.device("cuda"), dtype)

    def test_batch_consistency(self):
        """Batched inference matches per-sample inference."""
        model = squim_objective_base()
        model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)
        ref_scores = model(waveforms)
        # The model returns three score tensors; assemble per-sample results
        # into the same layout as the batched output.
        hyp_scores = [torch.zeros(batch_size) for _ in range(3)]
        for sample_idx in range(batch_size):
            sample_scores = model(waveforms[sample_idx : sample_idx + 1])
            for metric_idx in range(3):
                hyp_scores[metric_idx][sample_idx] = sample_scores[metric_idx]
        self.assertEqual(len(hyp_scores), len(ref_scores))
        for metric_idx in range(len(ref_scores)):
            self.assertEqual(hyp_scores[metric_idx], ref_scores[metric_idx])

    def test_torchscript_consistency(self):
        """Scripted model produces the same scores as eager mode."""
        model = squim_objective_base()
        model.eval()
        waveforms = torch.randn(3, 16000)
        ref_scores = model(waveforms)
        hyp_scores = torch_script(model)(waveforms)
        self.assertEqual(len(hyp_scores), len(ref_scores))
        for metric_idx in range(len(ref_scores)):
            self.assertEqual(hyp_scores[metric_idx], ref_scores[metric_idx])
class TestSquimSubjective(TorchaudioTestCase):
    """Tests for the SQUIM subjective-metric model (``squim_subjective_base``)."""

    def _smoke_test_subjective(self, model, device, dtype):
        """Run one forward pass on a random test/reference pair; passes if nothing raises."""
        model = model.to(device=device, dtype=dtype).eval()
        waveforms = torch.randn(3, 16000, device=device, dtype=dtype)
        reference = torch.randn(3, 16000, device=device, dtype=dtype)
        model(waveforms, reference)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """Model runs on CPU for both float dtypes."""
        self._smoke_test_subjective(squim_subjective_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """Model runs on CUDA for both float dtypes."""
        self._smoke_test_subjective(squim_subjective_base(), torch.device("cuda"), dtype)

    def test_batch_consistency(self):
        """Batched inference matches per-sample inference."""
        model = squim_subjective_base()
        model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)
        reference = torch.randn(batch_size, num_frames)
        ref_scores = model(waveforms, reference)
        per_sample = []
        for idx in range(batch_size):
            per_sample.append(model(waveforms[idx : idx + 1], reference[idx : idx + 1]))
        hyp_scores = torch.tensor(per_sample)
        self.assertEqual(hyp_scores, ref_scores)

    def test_torchscript_consistency(self):
        """Scripted model produces the same scores as eager mode."""
        model = squim_subjective_base()
        model.eval()
        waveforms = torch.randn(3, 16000)
        reference = torch.randn(3, 16000)
        ref_scores = model(waveforms, reference)
        hyp_scores = torch_script(model)(waveforms, reference)
        self.assertEqual(hyp_scores, ref_scores)
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
View file @
ffeba11a
...
...
@@ -42,7 +42,7 @@ class TorchscriptConsistencyMixin(TestBaseMixin):
class
Tacotron2EncoderTests
(
TorchscriptConsistencyMixin
):
@
skipIfPy310
#
@skipIfPy310
def
test_tacotron2_torchscript_consistency
(
self
):
r
"""Validate the torchscript consistency of a Encoder."""
n_batch
,
n_seq
,
encoder_embedding_dim
=
16
,
64
,
512
...
...
@@ -266,7 +266,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
(
16
,),
]
)
@
skipIfPy310
#
@skipIfPy310
def
test_tacotron2_torchscript_consistency
(
self
,
n_batch
):
r
"""Validate the torchscript consistency of a Tacotron2."""
n_mels
=
80
...
...
@@ -335,7 +335,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
(
16
,),
]
)
@
skipIfPy310
#
@skipIfPy310
def
test_tacotron2_inference_torchscript_consistency
(
self
,
n_batch
):
r
"""Validate the torchscript consistency of Tacotron2 inference function."""
n_mels
=
40
...
...
test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py
View file @
ffeba11a
...
...
@@ -9,9 +9,12 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large_lv60k
,
wav2vec2_xlsr_1b
,
wav2vec2_xlsr_2b
,
wav2vec2_xlsr_300m
,
)
from
torchaudio.models.wav2vec2.utils
import
import_fairseq_model
from
torchaudio_unittest.common_utils
import
get_asset_path
,
skipIfNoModule
,
TorchaudioTestCase
from
torchaudio_unittest.common_utils
import
get_asset_path
,
skipIfCudaSmallMemory
,
skipIfNoModule
,
TorchaudioTestCase
def
_load_config
(
*
paths
):
...
...
@@ -31,6 +34,9 @@ WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
HUBERT_BASE
=
_load_config
(
"hubert_base_ls960"
)
HUBERT_LARGE_LL60K
=
_load_config
(
"hubert_large_ll60k"
)
HUBERT_XLARGE_LL60K
=
_load_config
(
"hubert_xtralarge_ll60k"
)
WAV2VEC2_XLSR_300M
=
_load_config
(
"xlsr_300m"
)
WAV2VEC2_XLSR_1B
=
_load_config
(
"xlsr_1b"
)
WAV2VEC2_XLSR_2B
=
_load_config
(
"xlsr_2b"
)
# Finetuning models
WAV2VEC2_BASE_960H
=
_load_config
(
"wav2vec_small_960h"
)
WAV2VEC2_LARGE_960H
=
_load_config
(
"wav2vec_large_960h"
)
...
...
@@ -50,6 +56,14 @@ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
],
name_func
=
_name_func
,
)
XLSR_PRETRAINING_CONFIGS
=
parameterized
.
expand
(
[
(
WAV2VEC2_XLSR_300M
,
wav2vec2_xlsr_300m
),
(
WAV2VEC2_XLSR_1B
,
wav2vec2_xlsr_1b
),
(
WAV2VEC2_XLSR_2B
,
wav2vec2_xlsr_2b
),
],
name_func
=
_name_func
,
)
HUBERT_PRETRAINING_CONFIGS
=
parameterized
.
expand
(
[
(
HUBERT_BASE
,
hubert_base
),
...
...
@@ -134,7 +148,24 @@ class TestFairseqIntegration(TorchaudioTestCase):
hyp
,
_
=
imported
.
extract_features
(
x
)
refs
=
original
.
extract_features
(
x
,
padding_mask
=
torch
.
zeros_like
(
x
),
layer
=-
1
)
for
i
,
(
ref
,
_
)
in
enumerate
(
refs
[
"layer_results"
]):
self
.
assertEqual
(
hyp
[
i
],
ref
.
transpose
(
0
,
1
))
self
.
assertEqual
(
hyp
[
i
],
ref
.
transpose
(
0
,
1
),
atol
=
1.5e-5
,
rtol
=
1.3e-6
)
    @XLSR_PRETRAINING_CONFIGS
    @skipIfCudaSmallMemory
    def test_import_xlsr_pretraining_model(self, config, factory_func):
        """XLS-R pretraining models from fairseq can be imported and yields the same results"""
        batch_size, num_frames = 3, 1024
        original = self._get_model(config).eval()
        imported = import_fairseq_model(original).eval()

        x = torch.randn(batch_size, num_frames)
        hyp, _ = imported.extract_features(x)
        # layer=-1 requests intermediate outputs from every transformer layer.
        refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
        for i, (ref, _) in enumerate(refs["layer_results"]):
            # There is one element whose difference is over 1e-5 in wav2vec2_xlsr_1b and wav2vec2_xlsr_2b.
            atol = 1.0e-05 if factory_func is wav2vec2_xlsr_300m else 1e-4
            # fairseq layer outputs are (time, batch, feature); transpose to
            # match torchaudio's (batch, time, feature) layout before comparing.
            self.assertEqual(hyp[i], ref.transpose(0, 1), atol=atol, rtol=1.3e-6)
@
HUBERT_PRETRAINING_CONFIGS
def
test_import_hubert_pretraining_model
(
self
,
config
,
factory_func
):
...
...
@@ -150,15 +181,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
# check the last layer
ref
,
_
=
original
.
extract_features
(
x
,
padding_mask
=
mask
,
output_layer
=
len
(
original
.
encoder
.
layers
))
atol
=
3.0e-05
if
factory_func
is
hubert_xlarge
else
1.0e-5
self
.
assertEqual
(
hyp
[
-
1
],
ref
,
atol
=
atol
,
rtol
=
1.3e-6
)
self
.
assertEqual
(
hyp
[
-
1
],
ref
,
atol
=
3.0e-5
,
rtol
=
1.3e-6
)
# check the first layer
ref
,
_
=
original
.
extract_features
(
x
,
padding_mask
=
mask
,
output_layer
=
1
)
self
.
assertEqual
(
hyp
[
0
],
ref
)
@
ALL_PRETRAINING_CONFIGS
def
test_recreate_pretraining_model
(
self
,
config
,
factory_func
):
def
_test_recreate_pretraining_model
(
self
,
config
,
factory_func
):
"""Imported pretraining models can be recreated via a factory function without fairseq."""
batch_size
,
num_frames
=
3
,
1024
...
...
@@ -188,6 +217,15 @@ class TestFairseqIntegration(TorchaudioTestCase):
self
.
assertEqual
(
ref
,
hyp
)
self
.
assertEqual
(
ref_lengths
,
hyp_lengths
)
    @ALL_PRETRAINING_CONFIGS
    def test_wav2vec2_recreate_pretraining_model(self, config, factory_func):
        """Imported wav2vec2/HuBERT pretraining models can be recreated via a factory function."""
        self._test_recreate_pretraining_model(config, factory_func)
    @XLSR_PRETRAINING_CONFIGS
    @skipIfCudaSmallMemory
    def test_xlsr_recreate_pretraining_model(self, config, factory_func):
        """Imported XLS-R pretraining models can be recreated via a factory function."""
        self._test_recreate_pretraining_model(config, factory_func)
@
FINETUNING_CONFIGS
def
test_import_finetuning_model
(
self
,
config
,
_
):
"""Fintuned wav2vec2 models from fairseq can be imported and yields the same results"""
...
...
test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
View file @
ffeba11a
import
json
import
unittest
import
torch
from
parameterized
import
parameterized
from
torchaudio.models.wav2vec2
import
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large_lv60k
from
torchaudio.models.wav2vec2
import
(
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large_lv60k
,
wav2vec2_xlsr_1b
,
wav2vec2_xlsr_2b
,
wav2vec2_xlsr_300m
,
wavlm_base
,
wavlm_large
,
)
from
torchaudio.models.wav2vec2.utils
import
import_huggingface_model
from
torchaudio_unittest.common_utils
import
get_asset_path
,
skipIfNoModule
,
TorchaudioTestCase
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
skipIfCudaSmallMemory
,
skipIfNoModule
,
TorchaudioTestCase
,
zip_equal
,
)
def
_load_config
(
*
paths
):
...
...
@@ -22,6 +38,11 @@ HF_LARGE = _load_config("wav2vec2-large")
HF_LARGE_LV60
=
_load_config
(
"wav2vec2-large-lv60"
)
HF_LARGE_XLSR_53
=
_load_config
(
"wav2vec2-large-xlsr-53"
)
HF_BASE_10K_VOXPOPULI
=
_load_config
(
"wav2vec2-base-10k-voxpopuli"
)
HF_BASE_WAVLM
=
_load_config
(
"wavlm-base"
)
HF_LARGE_WAVLM
=
_load_config
(
"wavlm-large"
)
HF_XLSR_300M
=
_load_config
(
"wav2vec2-xls-r-300m"
)
HF_XLSR_1B
=
_load_config
(
"wav2vec2-xls-r-1b"
)
HF_XLSR_2B
=
_load_config
(
"wav2vec2-xls-r-2b"
)
# Finetuned
HF_BASE_960H
=
_load_config
(
"wav2vec2-base-960h"
)
HF_LARGE_960H
=
_load_config
(
"wav2vec2-large-960h"
)
...
...
@@ -40,6 +61,14 @@ PRETRAIN_CONFIGS = parameterized.expand(
],
name_func
=
_name_func
,
)
XLSR_PRETRAIN_CONFIGS
=
parameterized
.
expand
(
[
(
HF_XLSR_300M
,
wav2vec2_xlsr_300m
),
(
HF_XLSR_1B
,
wav2vec2_xlsr_1b
),
(
HF_XLSR_2B
,
wav2vec2_xlsr_2b
),
],
name_func
=
_name_func
,
)
FINETUNE_CONFIGS
=
parameterized
.
expand
(
[
(
HF_BASE_960H
,
wav2vec2_base
),
...
...
@@ -50,8 +79,16 @@ FINETUNE_CONFIGS = parameterized.expand(
],
name_func
=
_name_func
,
)
WAVLM_CONFIGS
=
parameterized
.
expand
(
[
(
HF_BASE_WAVLM
,
wavlm_base
),
(
HF_LARGE_WAVLM
,
wavlm_large
),
],
name_func
=
_name_func
,
)
@
unittest
.
skip
(
"transformers v4.30 seems to break the weight format. See https://github.com/pytorch/audio/issues/3430"
)
@
skipIfNoModule
(
"transformers"
)
class
TestHFIntegration
(
TorchaudioTestCase
):
"""Test the process of importing the models from Hugging Face Transformers
...
...
@@ -68,12 +105,14 @@ class TestHFIntegration(TorchaudioTestCase):
# However, somehow, once "transformers" is imported, `is_module_available`
# starts to fail. Therefore, we defer importing "transformers" until
# the actual tests are started.
from
transformers
.models.wav2vec2
import
Wav2Vec2Config
,
Wav2Vec2ForCTC
,
Wav2Vec2Model
from
transformers
import
Wav2Vec2Config
,
Wav2Vec2ForCTC
,
Wav2Vec2Model
,
WavLMConfig
,
WavLMModel
if
config
[
"architectures"
]
==
[
"Wav2Vec2Model"
]:
return
Wav2Vec2Model
(
Wav2Vec2Config
(
**
config
))
if
config
[
"architectures"
]
==
[
"Wav2Vec2ForCTC"
]:
return
Wav2Vec2ForCTC
(
Wav2Vec2Config
(
**
config
))
if
config
[
"architectures"
]
==
[
"WavLMModel"
]:
return
WavLMModel
(
WavLMConfig
(
**
config
))
raise
ValueError
(
f
'Unexpected arch:
{
config
[
"architectures"
]
}
'
)
def
_test_import_pretrain
(
self
,
original
,
imported
,
config
):
...
...
@@ -97,9 +136,8 @@ class TestHFIntegration(TorchaudioTestCase):
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
x
=
torch
.
randn
(
b
,
l
,
e
)
mask
=
torch
.
randn
(
b
,
1
,
l
,
l
)
(
ref
,)
=
original_
(
x
,
attention_mask
=
mask
,
output_attentions
=
False
)
hyp
=
imported_
(
x
,
mask
)
hyp
,
_
=
imported_
(
x
,
mask
)
# Ignore returned position_bias, which is always None for Wav2Vec2 and HuBERT
self
.
assertEqual
(
ref
,
hyp
)
# The whole Encoder Transformer
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
...
...
@@ -115,11 +153,6 @@ class TestHFIntegration(TorchaudioTestCase):
hyp
=
imported
.
aux
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole model without mask
x
=
torch
.
randn
(
3
,
1024
)
ref
=
original
(
x
).
logits
hyp
,
_
=
imported
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole model without mask
batch_size
,
num_frames
=
3
,
1024
x
=
torch
.
randn
(
batch_size
,
num_frames
)
ref
=
original
(
x
).
logits
...
...
@@ -151,6 +184,14 @@ class TestHFIntegration(TorchaudioTestCase):
imported
=
import_huggingface_model
(
original
).
eval
()
self
.
_test_import_pretrain
(
original
,
imported
,
config
)
    @XLSR_PRETRAIN_CONFIGS
    @skipIfCudaSmallMemory
    def test_import_xlsr_pretrain(self, config, _):
        """XLS-R models from HF transformers can be imported and yields the same results"""
        original = self._get_model(config).eval()
        imported = import_huggingface_model(original).eval()
        # Reuse the shared pretrain-comparison helper (compares feature
        # extractor, projection, and transformer outputs layer by layer).
        self._test_import_pretrain(original, imported, config)
@
FINETUNE_CONFIGS
def
test_import_finetune
(
self
,
config
,
_
):
"""wav2vec2 models from HF transformers can be imported and yields the same results"""
...
...
@@ -159,6 +200,51 @@ class TestHFIntegration(TorchaudioTestCase):
self
.
_test_import_pretrain
(
original
.
wav2vec2
,
imported
,
config
)
self
.
_test_import_finetune
(
original
,
imported
,
config
)
    @WAVLM_CONFIGS
    def test_import_pretrain_wavlm(self, config, _):
        """WavLM models from HF transformers can be imported and yield the same results

        Compares the HF model and the imported torchaudio model component by
        component: feature extractor, feature projection, positional encoder,
        each transformer layer, and the full encoder stack.
        """
        original = self._get_model(config).eval()
        imported = import_huggingface_model(original).eval()

        # FeatureExtractor
        x = torch.randn(3, 1024)
        # HF output is (batch, feature, time); transpose to torchaudio's layout.
        ref = original.feature_extractor(x).transpose(1, 2)
        hyp, _ = imported.feature_extractor(x, None)
        self.assertEqual(ref, hyp)

        # Feature projection
        x = torch.randn(3, 10, config["conv_dim"][-1])
        ref = original.feature_projection(x)[0]
        hyp = imported.encoder.feature_projection(x)
        self.assertEqual(ref, hyp)

        # Convolutional Positional Encoder
        x = torch.randn(3, 256, config["hidden_size"])
        ref = original.encoder.pos_conv_embed(x)
        hyp = imported.encoder.transformer.pos_conv_embed(x)
        self.assertEqual(ref, hyp)

        # Transformer layers: WavLM layers thread a relative position bias
        # through successive layers, so compare layer by layer while carrying
        # each model's bias forward independently.
        position_bias = None
        position_bias_imp = None
        assert len(original.encoder.layers) > 0
        for original_, imported_ in zip_equal(original.encoder.layers, imported.encoder.transformer.layers):
            b, l, e = 16, 3, config["hidden_size"]
            x = torch.randn(b, l, e)
            mask = torch.randn(b, l) > 0.5  # HF WaveLM model expects the mask to be binary
            # HF WaveLM model (original_) takes in "attention mask" but actually uses it as key padding mask:
            # https://github.com/huggingface/transformers/blob/b047472650cba259621549ac27b18fd2066ce18e/src/transformers/models/wavlm/modeling_wavlm.py#L495
            ref, position_bias = original_(x, attention_mask=mask, output_attentions=False, position_bias=position_bias)
            hyp, position_bias_imp = imported_(x, key_padding_mask=mask.ne(1), position_bias=position_bias_imp)
            # Masked-out elements are undefined in the output
            ref_filled = ref.masked_fill(~mask.unsqueeze(2), 0)
            hyp_filled = hyp.masked_fill(~mask.unsqueeze(2), 0)
            self.assertEqual(ref_filled, hyp_filled)

        # The whole Encoder Transformer
        b, l, e = 16, 3, config["hidden_size"]
        x = torch.randn(b, l, e)
        ref = original.encoder(x).last_hidden_state
        hyp = imported.encoder.transformer(x)
        self.assertEqual(ref, hyp)
def
_test_recreate
(
self
,
imported
,
reloaded
,
config
):
# FeatureExtractor
x
=
torch
.
randn
(
3
,
1024
)
...
...
@@ -221,3 +307,50 @@ class TestHFIntegration(TorchaudioTestCase):
reloaded
.
load_state_dict
(
imported
.
state_dict
())
reloaded
.
eval
()
self
.
_test_recreate
(
imported
,
reloaded
,
config
)
    @WAVLM_CONFIGS
    def test_recreate_wavlm(self, config, factory_func):
        """Imported models can be recreated via a factory function without Hugging Face transformers.

        Builds a fresh model with the factory, loads the imported state dict,
        and compares outputs component by component.
        """
        imported = import_huggingface_model(self._get_model(config)).eval()
        reloaded = factory_func()
        reloaded.load_state_dict(imported.state_dict())
        reloaded.eval()

        # FeatureExtractor
        x = torch.randn(3, 1024)
        ref, _ = imported.feature_extractor(x, None)
        hyp, _ = reloaded.feature_extractor(x, None)
        self.assertEqual(ref, hyp)

        # Feature projection
        x = torch.randn(3, 10, config["conv_dim"][-1])
        ref = imported.encoder.feature_projection(x)
        hyp = reloaded.encoder.feature_projection(x)
        self.assertEqual(ref, hyp)

        # Convolutional Positional Encoder
        x = torch.randn(3, 256, config["hidden_size"])
        ref = imported.encoder.transformer.pos_conv_embed(x)
        hyp = reloaded.encoder.transformer.pos_conv_embed(x)
        self.assertEqual(ref, hyp)

        # Encoder Transformer Layer
        # WavLM layers thread a relative position bias through successive
        # layers; carry each model's bias forward independently.
        position_bias_ref = None
        position_bias_hyp = None
        for imported_, reloaded_ in zip(imported.encoder.transformer.layers, reloaded.encoder.transformer.layers):
            b, l, e = 16, 3, config["hidden_size"]
            x = torch.randn(b, l, e)
            mask = torch.randn(b, l) > 0.5  # HuggingFace WavLM expects the mask to be binary
            ref, position_bias_ref = imported_(x, key_padding_mask=mask, position_bias=position_bias_ref)
            hyp, position_bias_hyp = reloaded_(x, key_padding_mask=mask, position_bias=position_bias_hyp)
            self.assertEqual(ref, hyp)

        # The whole Encoder Transformer
        # TODO: Add mask pattern. Expected mask shapes and values are different.
        b, l, e = 16, 3, config["hidden_size"]
        x = torch.randn(b, l, e)
        mask = torch.randn(b, 1, l, l)
        ref = imported.encoder.transformer(x)
        hyp = reloaded.encoder.transformer(x)
        self.assertEqual(ref, hyp)

        # The whole model
        x = torch.randn(3, 1024)
        ref, _ = imported(x)
        hyp, _ = reloaded(x)
        self.assertEqual(ref, hyp)
test/torchaudio_unittest/models/wav2vec2/model_test.py
View file @
ffeba11a
...
...
@@ -15,6 +15,8 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large_lv60k
,
wavlm_base
,
wavlm_large
,
)
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
skipIfNoQengine
,
torch_script
,
TorchaudioTestCase
...
...
@@ -41,6 +43,14 @@ factory_funcs = parameterized.expand(
name_func
=
_name_func
,
)
factory_funcs_wavlm
=
parameterized
.
expand
(
[
(
wavlm_base
,),
(
wavlm_large
,),
],
name_func
=
_name_func
,
)
factory_funcs_hubert_pretrain
=
parameterized
.
expand
(
[
(
hubert_pretrain_base
,),
...
...
@@ -278,6 +288,131 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self
.
_test_quantize_torchscript
(
factory_func
(
aux_num_out
=
32
))
class TestWavLMModel(TorchaudioTestCase):
    """Smoke, batch-consistency, TorchScript and quantization tests for WavLM."""

    def _smoke_test(self, model, device, dtype):
        """Run one forward pass on random input; passes if nothing raises."""
        model = model.to(device=device, dtype=dtype)
        model = model.eval()
        batch_size, num_frames = 3, 1024
        waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        model(waveforms)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        """Pretrain and finetune variants run on CPU for both float dtypes."""
        model = wavlm_base()
        self._smoke_test(model, torch.device("cpu"), dtype)
        model = wavlm_base(aux_num_out=32)
        self._smoke_test(model, torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        """Pretrain and finetune variants run on CUDA for both float dtypes."""
        model = wavlm_base()
        self._smoke_test(model, torch.device("cuda"), dtype)
        model = wavlm_base(aux_num_out=32)
        self._smoke_test(model, torch.device("cuda"), dtype)

    def _test_batch_consistency(self, model):
        """Compare batched inference against per-sample inference within tolerance."""
        model.eval()
        batch_size, max_frames = 5, 5 * 1024
        waveforms = torch.randn(batch_size, max_frames)
        # Batch process
        batch_logits, _ = model(waveforms)
        # Per-sample process
        for i in range(batch_size):
            single_logit, _ = model(waveforms[i : i + 1])
            batch_logit = batch_logits[i : i + 1]
            # Convert to probability so that it's easier to interpret the diff
            single_prob = F.softmax(single_logit, dim=2)
            batch_prob = F.softmax(batch_logit, dim=2)
            # We allow max atol=0.005 -> 0.5%
            self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)

    @factory_funcs_wavlm
    def test_pretrain_batch_consistency(self, factory_func):
        """Results from single process and batched process should be reasonably close"""
        self._test_batch_consistency(factory_func())

    @factory_funcs_wavlm
    def test_finetune_batch_consistency(self, factory_func):
        """Results from single process and batched process should be reasonably close"""
        self._test_batch_consistency(factory_func(aux_num_out=32))

    def _test_torchscript(self, model):
        """Scripted model must match the eager model exactly."""
        model.eval()
        batch_size, num_frames = 3, 1024
        waveforms = torch.randn(batch_size, num_frames)
        # Compute results with original model
        ref_out, ref_len = model(waveforms)
        # Compute results with scripted model
        scripted = torch_script(model)
        hyp_out, hyp_len = scripted(waveforms)
        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)

    @factory_funcs_wavlm
    def test_pretrain_torchscript(self, factory_func):
        """WavLM model should be scriptable"""
        self._test_torchscript(factory_func())

    @factory_funcs_wavlm
    def test_finetune_torchscript(self, factory_func):
        """WavLM model with a head should be scriptable"""
        self._test_torchscript(factory_func(aux_num_out=32))

    def _test_quantize_smoke_test(self, model):
        """Dynamic quantization succeeds and the quantized model runs."""
        model.eval()
        batch_size, num_frames = 3, 1024

        # Remove the weight normalization forward hook
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

        # A lazy way to check that Modules are different
        assert str(quantized) != str(model), "Dynamic quantization did not modify the module."

        waveforms = torch.randn(batch_size, num_frames)
        _, _ = quantized(waveforms)

    @factory_funcs_wavlm
    @skipIfNoQengine
    def test_quantize(self, factory_func):
        """WavLM should support basic quantization"""
        self._test_quantize_smoke_test(factory_func(aux_num_out=32))

    def _test_quantize_torchscript(self, model):
        """Quantized model survives TorchScript and matches its own eager output."""
        model.eval()
        batch_size, num_frames = 3, 1024

        # Remove the weight normalization forward hook
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
        quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)

        # A lazy way to check that Modules are different
        assert str(quantized) != str(model), "Dynamic quantization did not modify the module."

        waveforms = torch.randn(batch_size, num_frames)
        ref_out, ref_len = quantized(waveforms)

        # Script
        scripted = torch_script(quantized)
        hyp_out, hyp_len = scripted(waveforms)

        self.assertEqual(hyp_out, ref_out)
        self.assertEqual(hyp_len, ref_len)

    @factory_funcs_wavlm
    @skipIfNoQengine
    def test_quantize_torchscript(self, factory_func):
        """Quantized WavLM model should be scriptable"""
        self._test_quantize_torchscript(factory_func(aux_num_out=32))
def
_compute_label_frame
(
audio_frame
:
int
)
->
int
:
"""Compute number of frames in the label tensor based on
the number of frames in the audio tensor."""
...
...
test/torchaudio_unittest/sox_effect/dataset_test.py
View file @
ffeba11a
import
os
import
platform
import
sys
import
unittest
from
concurrent.futures
import
ProcessPoolExecutor
from
typing
import
List
,
Tuple
from
unittest
import
skipIf
...
...
@@ -94,8 +95,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
32
,
num_workers
=
16
,
num_workers
=
4
,
worker_init_fn
=
init_random_seed
,
multiprocessing_context
=
torch
.
multiprocessing
.
get_context
(
"spawn"
),
)
for
batch
in
loader
:
assert
batch
.
shape
==
(
32
,
2
,
2
*
sample_rate
)
...
...
@@ -115,8 +117,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
32
,
num_workers
=
16
,
num_workers
=
4
,
worker_init_fn
=
init_random_seed
,
multiprocessing_context
=
torch
.
multiprocessing
.
get_context
(
"spawn"
),
)
for
batch
in
loader
:
assert
batch
.
shape
==
(
32
,
2
,
2
*
sample_rate
)
...
...
@@ -131,6 +134,7 @@ def speed(path):
return
torchaudio
.
sox_effects
.
apply_effects_tensor
(
wav
,
sample_rate
,
effects
)[
0
]
@
unittest
.
skipIf
(
True
,
"Skipping this test because condition is True"
)
@
skipIfNoSox
class
TestProcessPoolExecutor
(
TempDirMixin
,
PytorchTestCase
):
backend
=
"sox_io"
...
...
test/torchaudio_unittest/sox_effect/smoke_test.py
View file @
ffeba11a
...
...
@@ -54,24 +54,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
_found
,
_sr
=
sox_effects
.
apply_effects_file
(
input_path
,
effects
,
normalize
=
False
,
channels_first
=
channels_first
)
@
parameterized
.
expand
(
load_params
(
"sox_effect_test_args.jsonl"
),
name_func
=
lambda
f
,
i
,
p
:
f
'
{
f
.
__name__
}
_
{
i
}
_
{
p
.
args
[
0
][
"effects"
][
0
][
0
]
}
'
,
)
def
test_apply_effects_fileobj
(
self
,
args
):
"""`apply_effects_file` should return identical data as sox command"""
dtype
=
"int32"
channels_first
=
True
effects
=
args
[
"effects"
]
num_channels
=
args
.
get
(
"num_channels"
,
2
)
input_sr
=
args
.
get
(
"input_sample_rate"
,
8000
)
input_path
=
self
.
get_temp_path
(
"input.wav"
)
data
=
get_wav_data
(
dtype
,
num_channels
,
channels_first
=
channels_first
)
save_wav
(
input_path
,
data
,
input_sr
,
channels_first
=
channels_first
)
with
open
(
input_path
,
"rb"
)
as
fileobj
:
_found
,
_sr
=
sox_effects
.
apply_effects_file
(
fileobj
,
effects
,
normalize
=
False
,
channels_first
=
channels_first
)
Prev
1
…
9
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment