Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ffeba11a
Commit
ffeba11a
authored
Sep 02, 2024
by
mayp777
Browse files
UPDATE
parent
29deb085
Changes
337
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1847 additions
and
137 deletions
+1847
-137
test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py
...io_unittest/functional/librosa_compatibility_cuda_test.py
+3
-1
test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py
...io_unittest/functional/librosa_compatibility_test_impl.py
+2
-2
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
..._unittest/functional/torchscript_consistency_cuda_test.py
+2
-1
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+52
-14
test/torchaudio_unittest/io/common.py
test/torchaudio_unittest/io/common.py
+16
-0
test/torchaudio_unittest/io/effector_test.py
test/torchaudio_unittest/io/effector_test.py
+102
-0
test/torchaudio_unittest/io/playback_test.py
test/torchaudio_unittest/io/playback_test.py
+65
-0
test/torchaudio_unittest/io/stream_reader_test.py
test/torchaudio_unittest/io/stream_reader_test.py
+645
-33
test/torchaudio_unittest/io/stream_writer_test.py
test/torchaudio_unittest/io/stream_writer_test.py
+449
-42
test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
+16
-0
test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py
...rchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py
+49
-0
test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
...io_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
+2
-2
test/torchaudio_unittest/models/squim/__init__.py
test/torchaudio_unittest/models/squim/__init__.py
+0
-0
test/torchaudio_unittest/models/squim/squim_test.py
test/torchaudio_unittest/models/squim/squim_test.py
+113
-0
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
+3
-3
test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py
...udio_unittest/models/wav2vec2/fairseq_integration_test.py
+44
-6
test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
...unittest/models/wav2vec2/huggingface_intergration_test.py
+143
-10
test/torchaudio_unittest/models/wav2vec2/model_test.py
test/torchaudio_unittest/models/wav2vec2/model_test.py
+135
-0
test/torchaudio_unittest/sox_effect/dataset_test.py
test/torchaudio_unittest/sox_effect/dataset_test.py
+6
-2
test/torchaudio_unittest/sox_effect/smoke_test.py
test/torchaudio_unittest/sox_effect/smoke_test.py
+0
-21
No files found.
Too many changes to show.
To preserve performance only
337 of 337+
files are displayed.
Plain diff
Email patch
test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py
View file @
ffeba11a
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
,
skipIfkmeMark
from
.librosa_compatibility_test_impl
import
Functional
,
FunctionalComplex
from
.librosa_compatibility_test_impl
import
Functional
,
FunctionalComplex
@
skipIfNoCuda
@
skipIfNoCuda
@
skipIfkmeMark
class
TestFunctionalCUDA
(
Functional
,
PytorchTestCase
):
class
TestFunctionalCUDA
(
Functional
,
PytorchTestCase
):
device
=
"cuda"
device
=
"cuda"
@
skipIfNoCuda
@
skipIfNoCuda
@
skipIfkmeMark
class
TestFunctionalComplexCUDA
(
FunctionalComplex
,
PytorchTestCase
):
class
TestFunctionalComplexCUDA
(
FunctionalComplex
,
PytorchTestCase
):
device
=
"cuda"
device
=
"cuda"
test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py
View file @
ffeba11a
import
unittest
import
unittest
from
distutils.version
import
Strict
Version
from
distutils.version
import
Loose
Version
import
torch
import
torch
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
...
@@ -77,7 +77,7 @@ class Functional(TestBaseMixin):
...
@@ -77,7 +77,7 @@ class Functional(TestBaseMixin):
def
test_create_mel_fb
(
def
test_create_mel_fb
(
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
,
mel_scale
=
"htk"
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
,
mel_scale
=
"htk"
):
):
if
norm
==
"slaney"
and
Strict
Version
(
librosa
.
__version__
)
<
Strict
Version
(
"0.7.2"
):
if
norm
==
"slaney"
and
Loose
Version
(
librosa
.
__version__
)
<
Loose
Version
(
"0.7.2"
):
self
.
skipTest
(
"Test is known to fail with older versions of librosa."
)
self
.
skipTest
(
"Test is known to fail with older versions of librosa."
)
if
self
.
device
!=
"cpu"
:
if
self
.
device
!=
"cpu"
:
self
.
skipTest
(
"No need to run this test on CUDA"
)
self
.
skipTest
(
"No need to run this test on CUDA"
)
...
...
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
View file @
ffeba11a
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
,
skipIfkmeMark
from
.torchscript_consistency_impl
import
Functional
,
FunctionalFloat32Only
from
.torchscript_consistency_impl
import
Functional
,
FunctionalFloat32Only
...
@@ -11,6 +11,7 @@ class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
...
@@ -11,6 +11,7 @@ class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
@
skipIfNoCuda
@
skipIfNoCuda
@
skipIfkmeMark
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
"cuda"
)
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
ffeba11a
...
@@ -585,22 +585,10 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -585,22 +585,10 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
(
tensor
,))
self
.
_assert_consistency
(
func
,
(
tensor
,))
@
common_utils
.
skipIfNoKaldi
def
test_compute_kaldi_pitch
(
self
):
if
self
.
dtype
!=
torch
.
float32
or
self
.
device
!=
torch
.
device
(
"cpu"
):
raise
unittest
.
SkipTest
(
"Only float32, cpu is supported."
)
def
func
(
tensor
):
sample_rate
:
float
=
44100.0
return
F
.
compute_kaldi_pitch
(
tensor
,
sample_rate
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
(
tensor
,))
def
test_resample_sinc
(
self
):
def
test_resample_sinc
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
sr1
,
sr2
=
16000
,
8000
sr1
,
sr2
=
16000
,
8000
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"sinc_interp
olatio
n"
)
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"sinc_interp
_han
n"
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
16000
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
16000
)
self
.
_assert_consistency
(
func
,
(
tensor
,))
self
.
_assert_consistency
(
func
,
(
tensor
,))
...
@@ -616,7 +604,9 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -616,7 +604,9 @@ class Functional(TempDirMixin, TestBaseMixin):
sr1
,
sr2
=
16000
,
8000
sr1
,
sr2
=
16000
,
8000
lowpass_filter_width
=
6
lowpass_filter_width
=
6
rolloff
=
0.99
rolloff
=
0.99
self
.
_assert_consistency
(
F
.
resample
,
(
tensor
,
sr1
,
sr2
,
lowpass_filter_width
,
rolloff
,
"kaiser_window"
,
beta
))
self
.
_assert_consistency
(
F
.
resample
,
(
tensor
,
sr1
,
sr2
,
lowpass_filter_width
,
rolloff
,
"sinc_interp_kaiser"
,
beta
)
)
def
test_phase_vocoder
(
self
):
def
test_phase_vocoder
(
self
):
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
...
@@ -756,6 +746,54 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -756,6 +746,54 @@ class Functional(TempDirMixin, TestBaseMixin):
specgram
=
torch
.
rand
(
num_channels
,
n_fft_bin
,
num_frames
,
dtype
=
self
.
complex_dtype
,
device
=
self
.
device
)
specgram
=
torch
.
rand
(
num_channels
,
n_fft_bin
,
num_frames
,
dtype
=
self
.
complex_dtype
,
device
=
self
.
device
)
self
.
_assert_consistency_complex
(
F
.
apply_beamforming
,
(
beamform_weights
,
specgram
))
self
.
_assert_consistency_complex
(
F
.
apply_beamforming
,
(
beamform_weights
,
specgram
))
@
common_utils
.
nested_params
(
[
"convolve"
,
"fftconvolve"
],
[
"full"
,
"valid"
,
"same"
],
)
def
test_convolve
(
self
,
fn
,
mode
):
leading_dims
=
(
2
,
3
,
2
)
L_x
,
L_y
=
32
,
55
x
=
torch
.
rand
(
*
leading_dims
,
L_x
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
y
=
torch
.
rand
(
*
leading_dims
,
L_y
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
self
.
_assert_consistency
(
getattr
(
F
,
fn
),
(
x
,
y
,
mode
))
@
common_utils
.
nested_params
([
True
,
False
])
def
test_add_noise
(
self
,
use_lengths
):
leading_dims
=
(
2
,
3
)
L
=
31
waveform
=
torch
.
rand
(
*
leading_dims
,
L
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
noise
=
torch
.
rand
(
*
leading_dims
,
L
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
if
use_lengths
:
lengths
=
torch
.
rand
(
*
leading_dims
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
else
:
lengths
=
None
snr
=
torch
.
rand
(
*
leading_dims
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
*
10
self
.
_assert_consistency
(
F
.
add_noise
,
(
waveform
,
noise
,
snr
,
lengths
))
@
common_utils
.
nested_params
([
True
,
False
])
def
test_speed
(
self
,
use_lengths
):
leading_dims
=
(
3
,
2
)
T
=
200
waveform
=
torch
.
rand
(
*
leading_dims
,
T
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
requires_grad
=
True
)
if
use_lengths
:
lengths
=
torch
.
randint
(
1
,
T
,
leading_dims
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
else
:
lengths
=
None
self
.
_assert_consistency
(
F
.
speed
,
(
waveform
,
1000
,
1.1
,
lengths
))
def
test_preemphasis
(
self
):
waveform
=
torch
.
rand
(
3
,
2
,
100
,
device
=
self
.
device
,
dtype
=
self
.
dtype
)
coeff
=
0.9
self
.
_assert_consistency
(
F
.
preemphasis
,
(
waveform
,
coeff
))
def
test_deemphasis
(
self
):
waveform
=
torch
.
rand
(
3
,
2
,
100
,
device
=
self
.
device
,
dtype
=
self
.
dtype
)
coeff
=
0.9
self
.
_assert_consistency
(
F
.
deemphasis
,
(
waveform
,
coeff
))
class
FunctionalFloat32Only
(
TestBaseMixin
):
class
FunctionalFloat32Only
(
TestBaseMixin
):
def
test_rnnt_loss
(
self
):
def
test_rnnt_loss
(
self
):
...
...
test/torchaudio_unittest/io/common.py
0 → 100644
View file @
ffeba11a
import
torchaudio
# If FFmpeg is 4.1 or older
# Tests that checks the number of output samples from OPUS fails
# They work on 4.2+
# Probably this commit fixed it.
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
def
lt42
():
ver
=
torchaudio
.
utils
.
ffmpeg_utils
.
get_versions
()[
"libavcodec"
]
# 5.1 libavcodec 59. 18.100
# 4.4 libavcodec 58.134.100
# 4.3 libavcodec 58. 91.100
# 4.2 libavcodec 58. 54.100
# 4.1 libavcodec 58. 35.100
return
ver
[
0
]
<
59
and
ver
[
1
]
<
54
test/torchaudio_unittest/io/effector_test.py
0 → 100644
View file @
ffeba11a
from
parameterized
import
parameterized
from
torchaudio.io
import
AudioEffector
from
torchaudio_unittest.common_utils
import
get_sinusoid
,
skipIfNoFFmpeg
,
TorchaudioTestCase
from
.common
import
lt42
@
skipIfNoFFmpeg
class
EffectorTest
(
TorchaudioTestCase
):
def
test_null
(
self
):
"""No effect and codec will return the same result"""
sample_rate
=
8000
frames_per_chunk
=
256
effector
=
AudioEffector
(
effect
=
None
,
format
=
None
)
original
=
get_sinusoid
(
n_channels
=
3
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
# one-go
output
=
effector
.
apply
(
original
,
sample_rate
)
self
.
assertEqual
(
original
,
output
)
# streaming
for
i
,
chunk
in
enumerate
(
effector
.
stream
(
original
,
sample_rate
,
frames_per_chunk
)):
start
=
i
*
frames_per_chunk
end
=
(
i
+
1
)
*
frames_per_chunk
self
.
assertEqual
(
original
[
start
:
end
,
:],
chunk
)
@
parameterized
.
expand
(
[
(
"ogg"
,
"flac"
),
# flac only supports s16 and s32
(
"ogg"
,
"opus"
),
# opus only supports 48k Hz
(
"ogg"
,
"vorbis"
),
# vorbis only supports stereo
# ("ogg", "vorbis", 44100),
# this fails with small descrepancy; 441024 vs 441000
# TODO: investigate
(
"wav"
,
None
),
(
"wav"
,
"pcm_u8"
),
(
"mp3"
,
None
),
(
"mulaw"
,
None
,
44100
),
# mulaw is encoded without header
]
)
def
test_formats
(
self
,
format
,
encoder
,
sample_rate
=
8000
):
"""Formats (some with restrictions) just work without an issue in effector"""
effector
=
AudioEffector
(
format
=
format
,
encoder
=
encoder
)
original
=
get_sinusoid
(
n_channels
=
3
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
output
=
effector
.
apply
(
original
,
sample_rate
)
# On 4.1 OPUS produces 8020 samples (extra 20)
# this has been fixed on 4.2+
if
encoder
==
"opus"
and
lt42
():
return
self
.
assertEqual
(
original
.
shape
,
output
.
shape
)
# Note
# MP3 adds padding which cannot be removed when the encoded data is written to
# file-like object without seek method.
# The number of padding is retrievable as `AVCoedcContext::initial_padding`
# https://ffmpeg.org/doxygen/4.1/structAVCodecContext.html#a8f95550ce04f236e9915516d04d3d1ab
# but this is not exposed yet.
# These "priming" samples have negative time stamp, so we can also add logic
# to discard them at decoding, however, as far as I checked, when data is loaded
# with StreamReader, the time stamp is reset. I tried options like avoid_negative_ts,
# https://ffmpeg.org/ffmpeg-formats.html
# but it made no difference. Perhaps this is because the information about negative
# timestamp is only available at encoding side, and it presumably is written to
# header file, but it is not happening somehow with file-like object.
# Need to investigate more to remove MP3 padding
if
format
==
"mp3"
:
return
for
chunk
in
effector
.
stream
(
original
,
sample_rate
,
frames_per_chunk
=
original
.
size
(
0
)):
self
.
assertEqual
(
original
.
shape
,
chunk
.
shape
)
@
parameterized
.
expand
([(
"loudnorm=I=-16:LRA=11:TP=-1.5"
,),
(
"volume=2"
,)])
def
test_effect
(
self
,
effect
):
sample_rate
=
8000
effector
=
AudioEffector
(
effect
=
effect
)
original
=
get_sinusoid
(
n_channels
=
3
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
output
=
effector
.
apply
(
original
,
sample_rate
)
self
.
assertEqual
(
original
.
shape
,
output
.
shape
)
def
test_resample
(
self
):
"""Resample option allows to change the sampling rate"""
sample_rate
=
8000
output_sample_rate
=
16000
num_channels
=
3
effector
=
AudioEffector
(
effect
=
"lowpass"
)
original
=
get_sinusoid
(
n_channels
=
num_channels
,
sample_rate
=
sample_rate
,
channels_first
=
False
)
output
=
effector
.
apply
(
original
,
sample_rate
,
output_sample_rate
)
self
.
assertEqual
(
output
.
shape
,
[
output_sample_rate
,
num_channels
])
for
chunk
in
effector
.
stream
(
original
,
sample_rate
,
output_sample_rate
=
output_sample_rate
,
frames_per_chunk
=
output_sample_rate
):
self
.
assertEqual
(
chunk
.
shape
,
[
output_sample_rate
,
num_channels
])
test/torchaudio_unittest/io/playback_test.py
0 → 100644
View file @
ffeba11a
from
unittest.mock
import
patch
import
torch
from
parameterized
import
parameterized
from
torchaudio.io
import
play_audio
,
StreamWriter
from
torchaudio_unittest.common_utils
import
get_sinusoid
,
skipIfNoAudioDevice
,
skipIfNoMacOS
,
TorchaudioTestCase
@
skipIfNoAudioDevice
@
skipIfNoMacOS
class
PlaybackInterfaceTest
(
TorchaudioTestCase
):
@
parameterized
.
expand
([(
"uint8"
,),
(
"int16"
,),
(
"int32"
,),
(
"int64"
,),
(
"float32"
,),
(
"float64"
,)])
@
patch
.
object
(
StreamWriter
,
"write_audio_chunk"
)
def
test_playaudio
(
self
,
dtype
,
writeaudio_mock
):
"""Test playaudio function.
The patch object is used to check if the data is written
to the output device stream, without playing the actual audio.
"""
dtype
=
getattr
(
torch
,
dtype
)
sample_rate
=
8000
waveform
=
get_sinusoid
(
frequency
=
440
,
sample_rate
=
sample_rate
,
duration
=
1
,
# seconds
n_channels
=
1
,
dtype
=
dtype
,
device
=
"cpu"
,
channels_first
=
False
,
)
play_audio
(
waveform
,
sample_rate
=
sample_rate
)
writeaudio_mock
.
assert_called
()
@
parameterized
.
expand
(
[
# Invalid number of dimensions (!= 2)
(
"int16"
,
1
,
"audiotoolbox"
),
(
"int16"
,
3
,
"audiotoolbox"
),
# Invalid tensor type
(
"complex64"
,
2
,
"audiotoolbox"
),
# Invalid output device
(
"int16"
,
2
,
"audiotool"
),
]
)
@
patch
.
object
(
StreamWriter
,
"write_audio_chunk"
)
def
test_playaudio_invalid_options
(
self
,
dtype
,
ndim
,
device
,
writeaudio_mock
):
"""Test playaudio function raises error with invalid options."""
dtype
=
getattr
(
torch
,
dtype
)
sample_rate
=
8000
waveform
=
get_sinusoid
(
frequency
=
440
,
sample_rate
=
sample_rate
,
duration
=
1
,
# seconds
n_channels
=
1
,
dtype
=
dtype
,
device
=
"cpu"
,
channels_first
=
False
,
).
squeeze
()
for
_
in
range
(
ndim
-
1
):
waveform
=
waveform
.
unsqueeze
(
-
1
)
with
self
.
assertRaises
(
ValueError
):
play_audio
(
waveform
,
sample_rate
=
sample_rate
,
device
=
device
)
test/torchaudio_unittest/io/stream_reader_test.py
View file @
ffeba11a
import
io
import
torch
import
torch
import
torchaudio
import
torchaudio
from
parameterized
import
parameterized
,
parameterized_class
from
parameterized
import
parameterized
,
parameterized_class
from
torchaudio_unittest.common_utils
import
(
from
torchaudio_unittest.common_utils
import
(
disabledInCI
,
get_asset_path
,
get_asset_path
,
get_image
,
get_image
,
get_sinusoid
,
get_wav_data
,
get_wav_data
,
is_ffmpeg_available
,
is_ffmpeg_available
,
nested_params
,
nested_params
,
...
@@ -12,24 +16,68 @@ from torchaudio_unittest.common_utils import (
...
@@ -12,24 +16,68 @@ from torchaudio_unittest.common_utils import (
save_image
,
save_image
,
save_wav
,
save_wav
,
skipIfNoFFmpeg
,
skipIfNoFFmpeg
,
skipIfNoHWAccel
,
TempDirMixin
,
TempDirMixin
,
TorchaudioTestCase
,
TorchaudioTestCase
,
)
)
if
is_ffmpeg_available
():
if
is_ffmpeg_available
():
from
torchaudio.io
import
(
from
torchaudio.io
import
StreamReader
,
StreamWriter
StreamReader
,
from
torchaudio.io._stream_reader
import
(
StreamReaderSourceAudioStream
,
ChunkTensor
,
StreamReaderSourceStream
,
OutputAudioStream
,
StreamReaderSourceVideoStream
,
OutputVideoStream
,
SourceAudioStream
,
SourceStream
,
SourceVideoStream
,
)
)
@
skipIfNoFFmpeg
class
ChunkTensorTest
(
TorchaudioTestCase
):
def
test_chunktensor
(
self
):
"""ChunkTensor serves as a replacement of tensor"""
data
=
torch
.
randn
((
256
,
2
))
pts
=
16.0
c
=
ChunkTensor
(
data
,
pts
)
assert
c
.
pts
==
pts
self
.
assertEqual
(
c
,
data
)
# method
sum_
=
c
.
sum
()
assert
isinstance
(
sum_
,
torch
.
Tensor
)
self
.
assertEqual
(
sum_
,
data
.
sum
())
# function form
min_
=
torch
.
min
(
c
)
assert
isinstance
(
min_
,
torch
.
Tensor
)
self
.
assertEqual
(
min_
,
torch
.
min
(
data
))
# attribute
t
=
c
.
T
assert
isinstance
(
t
,
torch
.
Tensor
)
self
.
assertEqual
(
t
,
data
.
T
)
# in-place op
c
[
0
]
=
0
self
.
assertEqual
(
c
,
data
)
# pass to other C++ code
buffer
=
io
.
BytesIO
()
w
=
StreamWriter
(
buffer
,
format
=
"wav"
)
w
.
add_audio_stream
(
8000
,
2
)
with
w
.
open
():
w
.
write_audio_chunk
(
0
,
c
)
w
.
write_audio_chunk
(
0
,
c
,
c
.
pts
)
################################################################################
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source
=
parameterized_class
(
_media_source
=
parameterized_class
(
(
"test_type"
,),
(
"test_type"
,),
[(
"str"
,),
(
"fileobj"
,)
,
(
"tensor"
,)
],
[(
"str"
,),
(
"fileobj"
,)],
class_name_func
=
lambda
cls
,
_
,
params
:
f
'
{
cls
.
__name__
}
_
{
params
[
"test_type"
]
}
'
,
class_name_func
=
lambda
cls
,
_
,
params
:
f
'
{
cls
.
__name__
}
_
{
params
[
"test_type"
]
}
'
,
)
)
...
@@ -47,13 +95,6 @@ class _MediaSourceMixin:
...
@@ -47,13 +95,6 @@ class _MediaSourceMixin:
self
.
src
=
path
self
.
src
=
path
elif
self
.
test_type
==
"fileobj"
:
elif
self
.
test_type
==
"fileobj"
:
self
.
src
=
open
(
path
,
"rb"
)
self
.
src
=
open
(
path
,
"rb"
)
elif
self
.
test_type
==
"tensor"
:
with
open
(
path
,
"rb"
)
as
fileobj
:
data
=
fileobj
.
read
()
self
.
src
=
torch
.
frombuffer
(
data
,
dtype
=
torch
.
uint8
)
print
(
self
.
src
.
data_ptr
())
print
(
len
(
data
))
print
(
self
.
src
.
shape
)
return
self
.
src
return
self
.
src
def
tearDown
(
self
):
def
tearDown
(
self
):
...
@@ -112,7 +153,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -112,7 +153,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
base_metadata
=
{}
base_metadata
=
{}
expected
=
[
expected
=
[
StreamReader
SourceVideoStream
(
SourceVideoStream
(
media_type
=
"video"
,
media_type
=
"video"
,
codec
=
"h264"
,
codec
=
"h264"
,
codec_long_name
=
"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"
,
codec_long_name
=
"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"
,
...
@@ -129,7 +170,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -129,7 +170,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height
=
180
,
height
=
180
,
frame_rate
=
25.0
,
frame_rate
=
25.0
,
),
),
StreamReader
SourceAudioStream
(
SourceAudioStream
(
media_type
=
"audio"
,
media_type
=
"audio"
,
codec
=
"aac"
,
codec
=
"aac"
,
codec_long_name
=
"AAC (Advanced Audio Coding)"
,
codec_long_name
=
"AAC (Advanced Audio Coding)"
,
...
@@ -145,7 +186,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -145,7 +186,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate
=
8000.0
,
sample_rate
=
8000.0
,
num_channels
=
2
,
num_channels
=
2
,
),
),
StreamReader
SourceStream
(
SourceStream
(
media_type
=
"subtitle"
,
media_type
=
"subtitle"
,
codec
=
"mov_text"
,
codec
=
"mov_text"
,
codec_long_name
=
"MOV text"
,
codec_long_name
=
"MOV text"
,
...
@@ -158,7 +199,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -158,7 +199,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
"language"
:
"eng"
,
"language"
:
"eng"
,
},
},
),
),
StreamReader
SourceVideoStream
(
SourceVideoStream
(
media_type
=
"video"
,
media_type
=
"video"
,
codec
=
"h264"
,
codec
=
"h264"
,
codec_long_name
=
"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"
,
codec_long_name
=
"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"
,
...
@@ -175,7 +216,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -175,7 +216,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height
=
270
,
height
=
270
,
frame_rate
=
29.97002997002997
,
frame_rate
=
29.97002997002997
,
),
),
StreamReader
SourceAudioStream
(
SourceAudioStream
(
media_type
=
"audio"
,
media_type
=
"audio"
,
codec
=
"aac"
,
codec
=
"aac"
,
codec_long_name
=
"AAC (Advanced Audio Coding)"
,
codec_long_name
=
"AAC (Advanced Audio Coding)"
,
...
@@ -191,7 +232,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -191,7 +232,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate
=
16000.0
,
sample_rate
=
16000.0
,
num_channels
=
2
,
num_channels
=
2
,
),
),
StreamReader
SourceStream
(
SourceStream
(
media_type
=
"subtitle"
,
media_type
=
"subtitle"
,
codec
=
"mov_text"
,
codec
=
"mov_text"
,
codec_long_name
=
"MOV text"
,
codec_long_name
=
"MOV text"
,
...
@@ -208,6 +249,98 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -208,6 +249,98 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
output
=
[
s
.
get_src_stream_info
(
i
)
for
i
in
range
(
6
)]
output
=
[
s
.
get_src_stream_info
(
i
)
for
i
in
range
(
6
)]
assert
expected
==
output
assert
expected
==
output
def
test_output_info
(
self
):
s
=
StreamReader
(
self
.
get_src
())
s
.
add_audio_stream
(
-
1
)
s
.
add_audio_stream
(
-
1
,
filter_desc
=
"aresample=8000"
)
s
.
add_audio_stream
(
-
1
,
filter_desc
=
"aformat=sample_fmts=s16p"
)
s
.
add_video_stream
(
-
1
)
s
.
add_video_stream
(
-
1
,
filter_desc
=
"fps=10"
)
s
.
add_video_stream
(
-
1
,
filter_desc
=
"format=rgb24"
)
s
.
add_video_stream
(
-
1
,
filter_desc
=
"scale=w=160:h=90"
)
# Note:
# Somehow only FFmpeg 5 reports invalid video frame rate. (24576/0)
# FFmpeg 4 and 6 work fine.
# Perhaps this is a regression in FFmpeg or it could actually originate
# from other libraries.
# It consistently fails with FFmpeg installed via conda, so we change
# the value based on FFmpeg version.
ver
=
torchaudio
.
utils
.
ffmpeg_utils
.
get_versions
()[
"libavutil"
]
print
(
ver
)
major
,
minor
,
_
=
ver
if
major
==
57
:
video_frame_rate
=
-
1
else
:
video_frame_rate
=
30000
/
1001
print
(
video_frame_rate
)
expected
=
[
OutputAudioStream
(
source_index
=
4
,
filter_description
=
"anull"
,
media_type
=
"audio"
,
format
=
"fltp"
,
sample_rate
=
16000.0
,
num_channels
=
2
,
),
OutputAudioStream
(
source_index
=
4
,
filter_description
=
"aresample=8000"
,
media_type
=
"audio"
,
format
=
"fltp"
,
sample_rate
=
8000.0
,
num_channels
=
2
,
),
OutputAudioStream
(
source_index
=
4
,
filter_description
=
"aformat=sample_fmts=s16p"
,
media_type
=
"audio"
,
format
=
"s16p"
,
sample_rate
=
16000.0
,
num_channels
=
2
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"null"
,
media_type
=
"video"
,
format
=
"yuv420p"
,
width
=
480
,
height
=
270
,
frame_rate
=
30000
/
1001
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"fps=10"
,
media_type
=
"video"
,
format
=
"yuv420p"
,
width
=
480
,
height
=
270
,
frame_rate
=
10
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"format=rgb24"
,
media_type
=
"video"
,
format
=
"rgb24"
,
width
=
480
,
height
=
270
,
frame_rate
=
30000
/
1001
,
),
OutputVideoStream
(
source_index
=
3
,
filter_description
=
"scale=w=160:h=90"
,
media_type
=
"video"
,
format
=
"yuv420p"
,
width
=
160
,
height
=
90
,
frame_rate
=
30000
/
1001
,
),
]
output
=
[
s
.
get_out_stream_info
(
i
)
for
i
in
range
(
s
.
num_out_streams
)]
assert
expected
==
output
def
test_id3tag
(
self
):
def
test_id3tag
(
self
):
"""get_metadata method can fetch id3tag properly"""
"""get_metadata method can fetch id3tag properly"""
s
=
StreamReader
(
self
.
get_src
(
"steam-train-whistle-daniel_simon.mp3"
))
s
=
StreamReader
(
self
.
get_src
(
"steam-train-whistle-daniel_simon.mp3"
))
...
@@ -418,15 +551,26 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -418,15 +551,26 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
if
i
>=
40
:
if
i
>=
40
:
break
break
def
test_seek
(
self
):
def
test_stream_requires_grad_false
(
self
):
"""Tensors produced by StreamReader are requires_grad=False"""
s
=
StreamReader
(
self
.
get_src
())
s
.
add_basic_audio_stream
(
frames_per_chunk
=
2000
)
s
.
add_basic_video_stream
(
frames_per_chunk
=
15
)
s
.
fill_buffer
()
audio
,
video
=
s
.
pop_chunks
()
assert
not
audio
.
_elem
.
requires_grad
assert
not
video
.
_elem
.
requires_grad
@
parameterized
.
expand
([
"key"
,
"any"
,
"precise"
])
def
test_seek
(
self
,
mode
):
"""Calling `seek` multiple times should not segfault"""
"""Calling `seek` multiple times should not segfault"""
s
=
StreamReader
(
self
.
get_src
())
s
=
StreamReader
(
self
.
get_src
())
for
i
in
range
(
10
):
for
i
in
range
(
10
):
s
.
seek
(
i
)
s
.
seek
(
i
,
mode
)
for
_
in
range
(
0
):
for
_
in
range
(
0
):
s
.
seek
(
0
)
s
.
seek
(
0
,
mode
)
for
i
in
range
(
10
,
0
,
-
1
):
for
i
in
range
(
10
,
0
,
-
1
):
s
.
seek
(
i
)
s
.
seek
(
i
,
mode
)
def
test_seek_negative
(
self
):
def
test_seek_negative
(
self
):
"""Calling `seek` with negative value should raise an exception"""
"""Calling `seek` with negative value should raise an exception"""
...
@@ -434,6 +578,232 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -434,6 +578,232 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with
self
.
assertRaises
(
RuntimeError
):
with
self
.
assertRaises
(
RuntimeError
):
s
.
seek
(
-
1.0
)
s
.
seek
(
-
1.0
)
def
test_seek_invalid_mode
(
self
):
"""Calling `seek` with an invalid model should raise an exception"""
s
=
StreamReader
(
self
.
get_src
())
with
self
.
assertRaises
(
ValueError
):
s
.
seek
(
10
,
"magic_seek"
)
@
parameterized
.
expand
(
[
# Test keyframe seek
# The source mp4 video has two key frames the first frame and 203rd frame at 8.08 second.
# If the seek time stamp is smaller than 8.08, it will seek into the first frame at 0.0 second.
(
"nasa_13013.mp4"
,
"key"
,
0.2
,
(
0
,
slice
(
None
))),
(
"nasa_13013.mp4"
,
"key"
,
8.04
,
(
0
,
slice
(
None
))),
(
"nasa_13013.mp4"
,
"key"
,
8.08
,
(
0
,
slice
(
202
,
None
))),
(
"nasa_13013.mp4"
,
"key"
,
8.12
,
(
0
,
slice
(
202
,
None
))),
# The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds.
# if we seek to a time stamp smaller than 0.4004 it will seek into the first frame at 0.0 second.
(
"nasa_13013.avi"
,
"key"
,
0.2
,
(
0
,
slice
(
None
))),
(
"nasa_13013.avi"
,
"key"
,
1.01
,
(
0
,
slice
(
24
,
None
))),
(
"nasa_13013.avi"
,
"key"
,
7.37
,
(
0
,
slice
(
216
,
None
))),
(
"nasa_13013.avi"
,
"key"
,
7.7
,
(
0
,
slice
(
216
,
None
))),
# Test precise seek
(
"nasa_13013.mp4"
,
"precise"
,
0.0
,
(
0
,
slice
(
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
0.2
,
(
0
,
slice
(
5
,
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
8.04
,
(
0
,
slice
(
201
,
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
8.08
,
(
0
,
slice
(
202
,
None
))),
(
"nasa_13013.mp4"
,
"precise"
,
8.12
,
(
0
,
slice
(
203
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
0.0
,
(
0
,
slice
(
None
))),
(
"nasa_13013.avi"
,
"precise"
,
0.2
,
(
0
,
slice
(
1
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
8.1
,
(
0
,
slice
(
238
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
8.14
,
(
0
,
slice
(
239
,
None
))),
(
"nasa_13013.avi"
,
"precise"
,
8.17
,
(
0
,
slice
(
240
,
None
))),
# Test precise seek on video with missing PTS
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
"precise"
,
0.0
,
(
0
,
slice
(
None
))),
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
"precise"
,
0.2
,
(
0
,
slice
(
4
,
None
))),
(
"RATRACE_wave_f_nm_np1_fr_goo_37.avi"
,
"precise"
,
0.3
,
(
0
,
slice
(
7
,
None
))),
# Test any seek
# The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds.
(
"nasa_13013.avi"
,
"any"
,
0.0
,
(
0
,
slice
(
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.56
,
(
0
,
slice
(
12
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
7.77
,
(
0
,
slice
(
228
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.2002
,
(
11
,
slice
(
12
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.233567
,
(
10
,
slice
(
12
,
None
))),
(
"nasa_13013.avi"
,
"any"
,
0.266933
,
(
9
,
slice
(
12
,
None
))),
]
)
def
test_seek_modes
(
self
,
src
,
mode
,
seek_time
,
ref_indices
):
"""We expect the following behaviour from the diferent kinds of seek:
- `key`: the reader will seek to the first keyframe from the timestamp given
- `precise`: the reader will seek to the first keyframe from the timestamp given
and start decoding from that position until the given timestmap (discarding all frames in between)
- `any`: the reader will seek to the colsest frame to the timestamp
given but if this is not a keyframe, the content will be the delta from other frames
To thest this behaviour we can parameterize the test with the tupple ref_indices. ref_indices[0]
is the expected index on the frames list decoded after seek and ref_indices[1] is exepected index for
the list of all frames decoded from the begining (reference frames). This test checks if
the reference frame at index ref_indices[1] is the same as ref_indices[0]. Plese note that with `any`
and `key` seek we only compare keyframes, but with `precise` seek we can compare any frame content.
"""
# Using the first video stream (which is not default video stream)
stream_index
=
0
# Decode all frames for reference
src_bin
=
self
.
get_src
(
src
)
s
=
StreamReader
(
src_bin
)
s
.
add_basic_video_stream
(
-
1
,
stream_index
=
stream_index
)
s
.
process_all_packets
()
(
ref_frames
,)
=
s
.
pop_chunks
()
s
.
seek
(
seek_time
,
mode
=
mode
)
s
.
process_all_packets
()
(
frame
,)
=
s
.
pop_chunks
()
hyp_index
,
ref_index
=
ref_indices
hyp
,
ref
=
frame
[
hyp_index
:],
ref_frames
[
ref_index
]
print
(
hyp
.
shape
,
ref
.
shape
)
self
.
assertEqual
(
hyp
,
ref
)
@parameterized.expand(
    [
        ("nasa_13013.mp4", [195, 3, 270, 480]),
        # RATRACE does not have valid PTS metadata.
        ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", [36, 3, 240, 560]),
    ]
)
def test_change_fps(self, src, shape):
    """Can change the FPS of videos"""
    tgt_frame_rate = 15
    reader = StreamReader(self.get_src(src))
    # The source must have a different frame rate, otherwise the test is vacuous.
    src_info = reader.get_src_stream_info(reader.default_video_stream)
    assert src_info.frame_rate != tgt_frame_rate
    reader.add_basic_video_stream(frames_per_chunk=-1, frame_rate=tgt_frame_rate)
    reader.process_all_packets()
    (chunk,) = reader.pop_chunks()
    assert chunk.shape == torch.Size(shape)
def test_invalid_chunk_option(self):
    """Passing invalid `frames_per_chunk` and `buffer_chunk_size` raises error"""
    reader = StreamReader(self.get_src())
    # Zero and negative values (other than the -1 sentinel) are rejected
    # for both audio and video output streams.
    invalid_combos = [(0, 3), (3, 0), (-2, 3), (3, -2)]
    for frames, buf_size in invalid_combos:
        with self.assertRaises(RuntimeError):
            reader.add_audio_stream(frames_per_chunk=frames, buffer_chunk_size=buf_size)
        with self.assertRaises(RuntimeError):
            reader.add_video_stream(frames_per_chunk=frames, buffer_chunk_size=buf_size)
def test_unchunked_stream(self):
    """`frames_per_chunk=-1` disable chunking.

    When chunking is disabled, frames contained in one AVFrame become one chunk.
    For video, that is always one frame, but for audio, it depends.
    """
    reader = StreamReader(self.get_src())
    reader.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=10000)
    reader.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=10000)
    reader.process_all_packets()
    video, audio = reader.pop_chunks()
    # The whole media is returned in a single pop.
    assert video.shape == torch.Size([390, 3, 270, 480])
    assert audio.shape == torch.Size([208896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_frames_per_chunk(self, fpc):
    """Changing frames_per_chunk does not change the returned content"""
    src = self.get_src()

    # Reference: decode everything as a single chunk.
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
    reader.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
    reader.process_all_packets()
    ref_video, ref_audio = reader.pop_chunks()

    # Re-decode in chunks of `fpc` frames; concatenation must match the reference.
    if self.test_type == "fileobj":
        src.seek(0)
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
    reader.add_audio_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
    chunks = list(reader.stream())
    video_chunks = torch.cat([c[0] for c in chunks if c[0] is not None])
    audio_chunks = torch.cat([c[1] for c in chunks if c[1] is not None])
    self.assertEqual(ref_video, video_chunks)
    self.assertEqual(ref_audio, audio_chunks)
def test_buffer_chunk_size(self):
    """`buffer_chunk_size=-1` does not drop frames."""
    src = self.get_src()

    # Unlimited buffer: every chunk is retained and can be popped in order.
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=30, buffer_chunk_size=-1)
    reader.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=-1)
    reader.process_all_packets()
    for _ in range(13):
        video, audio = reader.pop_chunks()
        assert video.shape == torch.Size([30, 3, 270, 480])
        assert audio.shape == torch.Size([16000, 2])
    # The tail: video is exhausted, audio has a partial chunk left.
    video, audio = reader.pop_chunks()
    assert video is None
    assert audio.shape == torch.Size([896, 2])

    # Limited buffer: only the most recent `buffer_chunk_size` chunks survive.
    if self.test_type == "fileobj":
        src.seek(0)
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=30, buffer_chunk_size=3)
    reader.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=3)
    reader.process_all_packets()
    for _ in range(2):
        video, audio = reader.pop_chunks()
        assert video.shape == torch.Size([30, 3, 270, 480])
        assert audio.shape == torch.Size([16000, 2])
    video, audio = reader.pop_chunks()
    assert video.shape == torch.Size([30, 3, 270, 480])
    assert audio.shape == torch.Size([896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_video_pts(self, fpc):
    """PTS values of the first frame are reported in .pts attribute"""
    # nasa_13013 video: 30000/1001 fps, 390 frames total.
    rate, num_frames = 30000 / 1001, 390
    expected = [i / rate for i in range(0, num_frames, fpc)]
    reader = StreamReader(self.get_src())
    reader.add_video_stream(fpc)
    observed = [video.pts for video, in reader.stream()]
    self.assertEqual(observed, expected)
@parameterized.expand([(256,), (512,), (1024,), (4086,)])
def test_audio_pts(self, fpc):
    """PTS values of the first frame are reported in .pts attribute"""
    # nasa_13013 audio: 16 kHz, 208896 samples total.
    rate, num_frames = 16000, 208896
    expected = [i / rate for i in range(0, num_frames, fpc)]
    reader = StreamReader(self.get_src())
    reader.add_audio_stream(fpc, buffer_chunk_size=-1)
    observed = [audio.pts for audio, in reader.stream()]
    self.assertEqual(observed, expected)
def test_pts_unchunked_process_all(self):
    """PTS is zero when loading the entire media with unchunked buffer"""
    reader = StreamReader(self.get_src())
    reader.add_audio_stream(-1, buffer_chunk_size=-1)
    reader.add_video_stream(-1, buffer_chunk_size=-1)
    reader.process_all_packets()
    audio, video = reader.pop_chunks()
    # Everything was decoded in one chunk starting at the beginning of the media.
    assert audio.pts == 0.0
    assert video.pts == 0.0
    assert audio.size(0) == 208896
    assert video.size(0) == 390
def test_pts_unchunked(self):
    """PTS grows proportionally to the number of frames decoded"""
    reader = StreamReader(self.get_src())
    reader.add_audio_stream(-1, buffer_chunk_size=-1)
    reader.add_video_stream(-1, buffer_chunk_size=-1)
    num_audio_frames, num_video_frames = 0, 0
    # Process one packet at a time, so that each pop corresponds to one AVFrame.
    while num_audio_frames < 208896 and num_video_frames < 390:
        reader.process_packet()
        audio, video = reader.pop_chunks()
        if audio is None and video is None:
            continue
        if audio is not None:
            # 16 kHz audio: PTS is frames-decoded-so-far over the sample rate.
            assert audio.pts == num_audio_frames / 16000
            num_audio_frames += audio.size(0)
        if video is not None:
            # 30000/1001 fps video: PTS advances by 1001/30000 per frame.
            assert video.pts == num_video_frames * 1001 / 30000
            num_video_frames += video.size(0)
def
_to_fltp
(
original
):
def
_to_fltp
(
original
):
"""Convert Tensor to float32 with value range [-1, 1]"""
"""Convert Tensor to float32 with value range [-1, 1]"""
...
@@ -493,11 +863,84 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
...
@@ -493,11 +863,84 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
if
self
.
test_type
==
"fileobj"
:
if
self
.
test_type
==
"fileobj"
:
src
.
seek
(
0
)
src
.
seek
(
0
)
self
.
_test_wav
(
src
,
original
,
fmt
=
None
)
self
.
_test_wav
(
src
,
original
,
fmt
=
None
)
# convert to float32
expected
=
_to_fltp
(
original
)
def test_audio_stream_format(self):
    "`format` argument properly changes the sample format of decoded audio"
    num_channels = 2
    src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels)

    # Reference data in each sample format.
    args = {
        "num_channels": num_channels,
        "normalize": False,
        "channels_first": False,
        "num_frames": 1 << 16,
    }
    u8 = get_wav_data("uint8", **args)
    s16 = get_wav_data("int16", **args)
    s64 = s32.to(torch.int64) * (1 << 32)
    f32 = get_wav_data("float32", **args)
    f64 = get_wav_data("float64", **args)

    # Decode the same source into every supported sample format
    # (interleaved and planar variants).
    reader = StreamReader(src)
    for fmt in ("u8", "u8p", "s16", "s16p", "s32", "s32p", "s64", "s64p", "flt", "fltp", "dbl", "dblp"):
        reader.add_basic_audio_stream(frames_per_chunk=-1, format=fmt)
    reader.process_all_packets()
    chunks = reader.pop_chunks()

    # 8-bit conversion can be off by one due to rounding.
    self.assertEqual(chunks[0], u8, atol=1, rtol=0)
    self.assertEqual(chunks[1], u8, atol=1, rtol=0)
    self.assertEqual(chunks[2], s16)
    self.assertEqual(chunks[3], s16)
    self.assertEqual(chunks[4], s32)
    self.assertEqual(chunks[5], s32)
    self.assertEqual(chunks[6], s64)
    self.assertEqual(chunks[7], s64)
    self.assertEqual(chunks[8], f32)
    self.assertEqual(chunks[9], f32)
    self.assertEqual(chunks[10], f64)
    self.assertEqual(chunks[11], f64)
@nested_params([4000, 16000])
def test_basic_audio_stream_sample_rate(self, sr):
    """`sample_rate` argument changes the sample_rate of decoded audio"""
    src_num_channels, src_sr = 2, 8000
    data = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False)
    path = self.get_temp_path("ref.wav")
    save_wav(path, data, src_sr, channels_first=False)

    reader = StreamReader(path)
    reader.add_basic_audio_stream(frames_per_chunk=-1, format="flt", sample_rate=sr)
    # Source info keeps the original rate; output info reflects the requested one.
    self.assertEqual(reader.get_src_stream_info(0).sample_rate, src_sr)
    self.assertEqual(reader.get_out_stream_info(0).sample_rate, sr)
    reader.process_all_packets()
    (chunks,) = reader.pop_chunks()
    # One second of audio resampled to `sr` frames.
    self.assertEqual(chunks.shape, [sr, src_num_channels])
@nested_params([1, 2, 3, 8, 16])
def test_basic_audio_stream_num_channels(self, num_channels):
    """`num_channels` argument changes the number of channels of decoded audio"""
    # NOTE: the docstring previously said "`sample_rate` argument" — the parameter
    # under test here is `num_channels`.
    src_num_channels, sr = 2, 8000
    data = get_sinusoid(sample_rate=sr, n_channels=src_num_channels, channels_first=False)
    path = self.get_temp_path("ref.wav")
    save_wav(path, data, sr, channels_first=False)

    reader = StreamReader(path)
    reader.add_basic_audio_stream(frames_per_chunk=-1, format="flt", num_channels=num_channels)
    # Source info keeps the original channel count; output info reflects the requested one.
    self.assertEqual(reader.get_src_stream_info(0).num_channels, src_num_channels)
    self.assertEqual(reader.get_out_stream_info(0).num_channels, num_channels)
    reader.process_all_packets()
    (chunks,) = reader.pop_chunks()
    self.assertEqual(chunks.shape, [sr, num_channels])
@
nested_params
(
@
nested_params
(
[
"int16"
,
"uint8"
,
"int32"
],
# "float", "double", "int64"]
[
"int16"
,
"uint8"
,
"int32"
],
# "float", "double", "int64"]
...
@@ -630,23 +1073,192 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
...
@@ -630,23 +1073,192 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
rgb
=
torch
.
empty
(
1
,
3
,
256
,
256
,
dtype
=
torch
.
uint8
)
rgb
=
torch
.
empty
(
1
,
3
,
256
,
256
,
dtype
=
torch
.
uint8
)
rgb
[
0
,
0
]
=
torch
.
arange
(
256
,
dtype
=
torch
.
uint8
).
reshape
([
1
,
-
1
])
rgb
[
0
,
0
]
=
torch
.
arange
(
256
,
dtype
=
torch
.
uint8
).
reshape
([
1
,
-
1
])
rgb
[
0
,
1
]
=
torch
.
arange
(
256
,
dtype
=
torch
.
uint8
).
reshape
([
-
1
,
1
])
rgb
[
0
,
1
]
=
torch
.
arange
(
256
,
dtype
=
torch
.
uint8
).
reshape
([
-
1
,
1
])
alpha
=
torch
.
full
((
1
,
1
,
256
,
256
),
255
,
dtype
=
torch
.
uint8
)
for
i
in
range
(
256
):
for
i
in
range
(
256
):
rgb
[
0
,
2
]
=
i
rgb
[
0
,
2
]
=
i
path
=
self
.
get_temp_path
(
f
"ref_
{
i
}
.png"
)
path
=
self
.
get_temp_path
(
f
"ref_
{
i
}
.png"
)
save_image
(
path
,
rgb
[
0
],
mode
=
"RGB"
)
save_image
(
path
,
rgb
[
0
],
mode
=
"RGB"
)
rgb16
=
((
rgb
.
to
(
torch
.
int32
)
-
128
)
<<
8
).
to
(
torch
.
int16
)
yuv
=
rgb_to_yuv_ccir
(
rgb
)
yuv
=
rgb_to_yuv_ccir
(
rgb
)
yuv16
=
yuv
.
to
(
torch
.
int16
)
*
4
bgr
=
rgb
[:,
[
2
,
1
,
0
],
:,
:]
bgr
=
rgb
[:,
[
2
,
1
,
0
],
:,
:]
gray
=
rgb_to_gray
(
rgb
)
gray
=
rgb_to_gray
(
rgb
)
argb
=
torch
.
cat
([
alpha
,
rgb
],
dim
=
1
)
rgba
=
torch
.
cat
([
rgb
,
alpha
],
dim
=
1
)
abgr
=
torch
.
cat
([
alpha
,
bgr
],
dim
=
1
)
bgra
=
torch
.
cat
([
bgr
,
alpha
],
dim
=
1
)
s
=
StreamReader
(
path
)
s
=
StreamReader
(
path
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv444p"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv444p"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv420p"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"nv12"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgb24"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgb24"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"bgr24"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"bgr24"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"gray8"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"gray8"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgb48le"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"argb"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"rgba"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"abgr"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"bgra"
)
s
.
add_basic_video_stream
(
frames_per_chunk
=-
1
,
format
=
"yuv420p10le"
)
s
.
process_all_packets
()
s
.
process_all_packets
()
output_yuv
,
output_rgb
,
output_bgr
,
output_gray
=
s
.
pop_chunks
()
chunks
=
s
.
pop_chunks
()
self
.
assertEqual
(
yuv
,
output_yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
0
],
yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
rgb
,
output_rgb
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
1
],
yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
bgr
,
output_bgr
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
2
],
yuv
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
gray
,
output_gray
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
3
],
rgb
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
4
],
bgr
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
5
],
gray
,
atol
=
1
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
6
],
rgb16
,
atol
=
256
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
7
],
argb
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
8
],
rgba
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
9
],
abgr
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
10
],
bgra
,
atol
=
0
,
rtol
=
0
)
self
.
assertEqual
(
chunks
[
11
],
yuv16
,
atol
=
4
,
rtol
=
0
)
@skipIfNoHWAccel("h264_cuvid")
class CuvidHWAccelInterfaceTest(TorchaudioTestCase):
    def test_dup_hw_acel(self):
        """Specifying the same source stream with and without HW accel should fail (instead of segfault later)"""
        src = get_asset_path("nasa_13013.mp4")

        # CPU output first, then a CUDA output for the same stream: rejected.
        reader = StreamReader(src)
        reader.add_video_stream(-1, decoder="h264_cuvid")
        with self.assertRaises(RuntimeError):
            reader.add_video_stream(-1, decoder="h264_cuvid", hw_accel="cuda")

        # CUDA output first, then a CPU output: rejected as well.
        reader = StreamReader(src)
        reader.add_video_stream(-1, decoder="h264_cuvid", hw_accel="cuda")
        with self.assertRaises(RuntimeError):
            reader.add_video_stream(-1, decoder="h264_cuvid")
@_media_source
class CudaDecoderTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
    def _test_decode(
        self,
        decoder: str,
        src_path: str,
        height: int,
        width: int,
        ref_num_frames: int,
        hw_accel=None,
        decoder_option=None,
        dtype: torch.dtype = torch.uint8,
    ):
        # Decode `src_path` with the given GPU decoder and verify device, dtype,
        # per-chunk shape and total frame count of every streamed chunk.
        src = self.get_src(get_asset_path(src_path))
        reader = StreamReader(src)
        reader.add_video_stream(10, decoder=decoder, decoder_option=decoder_option, hw_accel=hw_accel)
        num_frames = 0
        for (chunk,) in reader.stream():
            # Without hw_accel the frames land on CPU even though decoding is on GPU.
            self.assertEqual(chunk.device, torch.device(hw_accel or "cpu"))
            self.assertEqual(chunk.dtype, dtype)
            self.assertEqual(chunk.shape, torch.Size([10, 3, height, width]))
            num_frames += chunk.size(0)
        assert num_frames == ref_num_frames

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid(self):
        """GPU decoder works for H264"""
        self._test_decode("h264_cuvid", "nasa_13013.mp4", 270, 480, 390)

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid_hw_accel(self):
        """GPU decoder works for H264 with HW acceleration, and put the frames on CUDA tensor"""
        self._test_decode("h264_cuvid", "nasa_13013.mp4", 270, 480, 390, hw_accel="cuda:0")

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid_hw_accel_resize(self):
        """GPU decoder works for H264 with HW acceleration and resize option"""
        w, h = 240, 136
        self._test_decode(
            "h264_cuvid",
            "nasa_13013.mp4",
            h,
            w,
            390,
            hw_accel="cuda:0",
            decoder_option={"resize": f"{w}x{h}"},
        )

    @skipIfNoHWAccel("h264_cuvid")
    def test_h264_cuvid_hw_accel_crop(self):
        """GPU decoder works for H264 with HW acceleration and crop option"""
        top, bottom, left, right = 3, 5, 7, 9
        # 270 - (3 + 5) = 262, 480 - (7 + 9) = 464
        self._test_decode(
            "h264_cuvid",
            "nasa_13013.mp4",
            262,
            464,
            390,
            hw_accel="cuda:0",
            decoder_option={"crop": f"{top}x{bottom}x{left}x{right}"},
        )

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid(self):
        """GPU decoder works for H265/HEVC"""
        self._test_decode("hevc_cuvid", "testsrc.hevc", 144, 256, 300)

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid_hw_accel(self):
        """GPU decoder works for H265/HEVC with HW acceleration, and put the frames on CUDA tensor"""
        self._test_decode("hevc_cuvid", "testsrc.hevc", 144, 256, 300, hw_accel="cuda:0", dtype=torch.int16)

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid_hw_accel_resize(self):
        """GPU decoder works for H265/HEVC with HW acceleration and resize option"""
        w, h = 128, 64
        self._test_decode(
            "hevc_cuvid",
            "testsrc.hevc",
            h,
            w,
            300,
            hw_accel="cuda:0",
            dtype=torch.int16,
            decoder_option={"resize": f"{w}x{h}"},
        )

    @skipIfNoHWAccel("hevc_cuvid")
    def test_hevc_cuvid_hw_accel_crop(self):
        """GPU decoder works for H265/HEVC with HW acceleration and crop option"""
        top, bottom, left, right = 3, 5, 7, 9
        # 144 - (3 + 5) = 136, 256 - (7 + 9) = 240
        self._test_decode(
            "hevc_cuvid",
            "testsrc.hevc",
            136,
            240,
            300,
            hw_accel="cuda:0",
            dtype=torch.int16,
            decoder_option={"crop": f"{top}x{bottom}x{left}x{right}"},
        )
@skipIfNoHWAccel("h264_cuvid")
# Disabled in CI: https://github.com/pytorch/audio/issues/3376
@disabledInCI
class FilterGraphWithCudaAccel(TorchaudioTestCase):
    # NOTE(review): the first test method's name has a typo ("sclae" -> "scale");
    # kept as-is so the public test ID does not change.
    def test_sclae_cuda_change_size(self):
        """scale_cuda filter can be used when HW accel is on"""
        src = get_asset_path("nasa_13013.mp4")
        reader = StreamReader(src)
        reader.add_video_stream(
            10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=iw/2:ih/2"
        )
        num_frames = 0
        for (chunk,) in reader.stream():
            self.assertEqual(chunk.device, torch.device("cuda:0"))
            self.assertEqual(chunk.dtype, torch.uint8)
            # Source is 270x480; scale_cuda halves both dimensions.
            self.assertEqual(chunk.shape, torch.Size([10, 3, 135, 240]))
            num_frames += chunk.size(0)
        assert num_frames == 390

    def test_scale_cuda_format(self):
        """yuv444p format conversion should work"""
        src = get_asset_path("nasa_13013.mp4")
        reader = StreamReader(src)
        reader.add_video_stream(
            10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=format=yuv444p"
        )
        num_frames = 0
        for (chunk,) in reader.stream():
            self.assertEqual(chunk.device, torch.device("cuda:0"))
            self.assertEqual(chunk.dtype, torch.uint8)
            # Format conversion only; resolution is unchanged.
            self.assertEqual(chunk.shape, torch.Size([10, 3, 270, 480]))
            num_frames += chunk.size(0)
        assert num_frames == 390
test/torchaudio_unittest/io/stream_writer_test.py
View file @
ffeba11a
import
io
import
math
import
torch
import
torch
import
torchaudio
import
torchaudio
from
parameterized
import
parameterized
,
parameterized_class
from
parameterized
import
parameterized
,
parameterized_class
from
torchaudio_unittest.common_utils
import
(
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
get_asset_path
,
get_sinusoid
,
is_ffmpeg_available
,
is_ffmpeg_available
,
nested_params
,
nested_params
,
rgb_to_yuv_ccir
,
rgb_to_yuv_ccir
,
...
@@ -13,8 +17,10 @@ from torchaudio_unittest.common_utils import (
...
@@ -13,8 +17,10 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase
,
TorchaudioTestCase
,
)
)
from
.common
import
lt42
if
is_ffmpeg_available
():
if
is_ffmpeg_available
():
from
torchaudio.io
import
StreamReader
,
StreamWriter
from
torchaudio.io
import
CodecConfig
,
StreamReader
,
StreamWriter
def
get_audio_chunk
(
fmt
,
sample_rate
,
num_channels
):
def
get_audio_chunk
(
fmt
,
sample_rate
,
num_channels
):
...
@@ -87,9 +93,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -87,9 +93,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
def
get_dst
(
self
,
path
):
def
get_dst
(
self
,
path
):
return
super
().
get_dst
(
self
.
get_temp_path
(
path
))
return
super
().
get_dst
(
self
.
get_temp_path
(
path
))
def
get_buf
(
self
,
path
):
def
test_unopened_error
(
self
):
with
open
(
self
.
get_temp_path
(
path
),
"rb"
)
as
fileobj
:
"""If dst is not opened when attempting to write data, runtime error should be raised"""
return
fileobj
.
read
()
path
=
self
.
get_dst
(
"test.mp4"
)
s
=
StreamWriter
(
path
,
format
=
"mp4"
)
s
.
set_metadata
(
metadata
=
{
"artist"
:
"torchaudio"
,
"title"
:
self
.
id
()})
s
.
add_audio_stream
(
sample_rate
=
16000
,
num_channels
=
2
)
s
.
add_video_stream
(
frame_rate
=
30
,
width
=
16
,
height
=
16
)
dummy
=
torch
.
zeros
((
3
,
2
))
with
self
.
assertRaises
(
RuntimeError
):
s
.
write_audio_chunk
(
0
,
dummy
)
dummy
=
torch
.
zeros
((
3
,
3
,
16
,
16
))
with
self
.
assertRaises
(
RuntimeError
):
s
.
write_video_chunk
(
1
,
dummy
)
@
skipIfNoModule
(
"tinytag"
)
@
skipIfNoModule
(
"tinytag"
)
def
test_metadata_overwrite
(
self
):
def
test_metadata_overwrite
(
self
):
...
@@ -135,21 +153,26 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -135,21 +153,26 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
@
parameterized
.
expand
(
@
parameterized
.
expand
(
[
[
(
"mp3"
,
8000
,
1
,
"s32p"
,
None
),
(
"mp3"
,
8000
,
1
,
None
,
"s32p"
,
None
),
(
"mp3"
,
16000
,
2
,
"fltp"
,
None
),
(
"mp3"
,
16000
,
2
,
None
,
"fltp"
,
None
),
(
"mp3"
,
44100
,
1
,
"s16p"
,
{
"abr"
:
"true"
}),
(
"mp3"
,
44100
,
1
,
None
,
"s16p"
,
{
"abr"
:
"true"
}),
(
"flac"
,
8000
,
1
,
"s16"
,
None
),
(
"flac"
,
8000
,
1
,
None
,
"s16"
,
None
),
(
"flac"
,
16000
,
2
,
"s32"
,
None
),
(
"flac"
,
16000
,
2
,
None
,
"s32"
,
None
),
(
"opus"
,
48000
,
2
,
None
,
{
"strict"
:
"experimental"
}),
(
"opus"
,
48000
,
2
,
"opus"
,
None
,
None
),
(
"adts"
,
8000
,
1
,
"fltp"
,
None
),
# AAC format
(
"ogg"
,
48000
,
2
,
"vorbis"
,
None
,
None
),
(
"adts"
,
8000
,
1
,
None
,
"fltp"
,
None
),
# AAC format
]
]
)
)
def
test_valid_audio_muxer_and_codecs
(
self
,
ext
,
sample_rate
,
num_channels
,
encoder_format
,
encoder_option
):
def
test_valid_audio_muxer_and_codecs
(
self
,
ext
,
sample_rate
,
num_channels
,
encoder
,
encoder_format
,
encoder_option
):
"""Tensor of various dtypes can be saved as given format."""
"""Tensor of various dtypes can be saved as given format."""
path
=
self
.
get_dst
(
f
"test.
{
ext
}
"
)
path
=
self
.
get_dst
(
f
"test.
{
ext
}
"
)
s
=
StreamWriter
(
path
,
format
=
ext
)
s
=
StreamWriter
(
path
,
format
=
ext
)
s
.
set_metadata
(
metadata
=
{
"artist"
:
"torchaudio"
,
"title"
:
self
.
id
()})
s
.
set_metadata
(
metadata
=
{
"artist"
:
"torchaudio"
,
"title"
:
self
.
id
()})
s
.
add_audio_stream
(
sample_rate
,
num_channels
,
encoder_option
=
encoder_option
,
encoder_format
=
encoder_format
)
s
.
add_audio_stream
(
sample_rate
,
num_channels
,
encoder
=
encoder
,
encoder_option
=
encoder_option
,
encoder_format
=
encoder_format
)
chunk
=
get_audio_chunk
(
"flt"
,
sample_rate
,
num_channels
)
chunk
=
get_audio_chunk
(
"flt"
,
sample_rate
,
num_channels
)
with
s
.
open
():
with
s
.
open
():
...
@@ -202,6 +225,19 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -202,6 +225,19 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s
.
write_audio_chunk
(
0
,
audio
)
s
.
write_audio_chunk
(
0
,
audio
)
s
.
write_video_chunk
(
1
,
video
)
s
.
write_video_chunk
(
1
,
video
)
@
skipIfNoFFmpeg
class
StreamWriterCorrectnessTest
(
TempDirMixin
,
TorchaudioTestCase
):
@classmethod
def setUpClass(cls):
    """Raise the FFmpeg log level so encode/decode issues surface in test logs."""
    super().setUpClass()
    torchaudio.utils.ffmpeg_utils.set_log_level(32)
@classmethod
def tearDownClass(cls):
    """Restore the quieter FFmpeg log level set before this class ran."""
    torchaudio.utils.ffmpeg_utils.set_log_level(8)
    super().tearDownClass()
@
nested_params
(
@
nested_params
(
[
[
(
"gray8"
,
"gray8"
),
(
"gray8"
,
"gray8"
),
...
@@ -227,16 +263,16 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -227,16 +263,16 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
chunk
=
torch
.
randint
(
low
=
0
,
high
=
255
,
size
=
src_size
,
dtype
=
torch
.
uint8
)
chunk
=
torch
.
randint
(
low
=
0
,
high
=
255
,
size
=
src_size
,
dtype
=
torch
.
uint8
)
# Write data
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
s
=
StreamWriter
(
dst
,
format
=
"rawvideo"
)
s
=
StreamWriter
(
dst
,
format
=
"rawvideo"
)
s
.
add_video_stream
(
frame_rate
,
width
,
height
,
format
=
src_fmt
,
encoder_format
=
encoder_fmt
)
s
.
add_video_stream
(
frame_rate
,
width
,
height
,
format
=
src_fmt
,
encoder_format
=
encoder_fmt
)
with
s
.
open
():
with
s
.
open
():
s
.
write_video_chunk
(
0
,
chunk
)
s
.
write_video_chunk
(
0
,
chunk
)
# Fetch the written data
# Fetch the written data
if
self
.
test_
fileobj
:
with
open
(
dst
,
"rb"
)
as
fileobj
:
dst
.
flush
()
buf
=
fileobj
.
read
()
buf
=
self
.
get_buf
(
filename
)
result
=
torch
.
frombuffer
(
buf
,
dtype
=
torch
.
uint8
)
result
=
torch
.
frombuffer
(
buf
,
dtype
=
torch
.
uint8
)
if
encoder_fmt
.
endswith
(
"p"
):
if
encoder_fmt
.
endswith
(
"p"
):
result
=
result
.
reshape
(
src_size
)
result
=
result
.
reshape
(
src_size
)
...
@@ -261,14 +297,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -261,14 +297,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
h
,
w
=
resolution
h
,
w
=
resolution
# Write data
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
.
add_video_stream
(
frame_rate
=
framerate
,
height
=
h
,
width
=
w
,
format
=
format
)
s
.
add_video_stream
(
frame_rate
=
framerate
,
height
=
h
,
width
=
w
,
format
=
format
)
chunk
=
torch
.
stack
([
torch
.
full
((
3
,
h
,
w
),
i
,
dtype
=
torch
.
uint8
)
for
i
in
torch
.
linspace
(
0
,
255
,
256
)])
chunk
=
torch
.
stack
([
torch
.
full
((
3
,
h
,
w
),
i
,
dtype
=
torch
.
uint8
)
for
i
in
torch
.
linspace
(
0
,
255
,
256
)])
with
s
.
open
():
with
s
.
open
():
s
.
write_video_chunk
(
0
,
chunk
)
s
.
write_video_chunk
(
0
,
chunk
)
if
self
.
test_fileobj
:
dst
.
flush
()
# Load data
# Load data
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
...
@@ -293,30 +327,54 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -293,30 +327,54 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
pass
pass
@
nested_params
(
@
nested_params
(
[
"wav"
,
"mp3"
,
"flac"
],
[
"wav"
,
"flac"
],
[
8000
,
16000
,
44100
],
[
8000
,
16000
,
44100
],
[
1
,
2
],
[
1
,
2
],
)
)
def
test_audio_num_frames
(
self
,
ext
,
sample_rate
,
num_channels
):
def
test_audio_num_frames
_lossless
(
self
,
ext
,
sample_rate
,
num_channels
):
""""""
"""
Lossless format preserves the data
"""
filename
=
f
"test.
{
ext
}
"
filename
=
f
"test.
{
ext
}
"
data
=
get_sinusoid
(
sample_rate
=
sample_rate
,
n_channels
=
num_channels
,
dtype
=
"int16"
,
channels_first
=
False
)
# Write data
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
)
s
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
format
=
"s16"
)
with
s
.
open
():
s
.
write_audio_chunk
(
0
,
data
)
freq
=
300
# Load data
duration
=
60
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
theta
=
torch
.
linspace
(
0
,
freq
*
2
*
3.14
*
duration
,
sample_rate
*
duration
)
s
.
add_audio_stream
(
-
1
)
if
num_channels
==
1
:
s
.
process_all_packets
()
chunk
=
torch
.
sin
(
theta
).
unsqueeze
(
-
1
)
(
saved
,)
=
s
.
pop_chunks
()
else
:
chunk
=
torch
.
stack
([
torch
.
sin
(
theta
),
torch
.
cos
(
theta
)],
dim
=-
1
)
self
.
assertEqual
(
saved
,
data
)
@
parameterized
.
expand
(
[
(
"mp3"
,
1
,
8000
),
(
"mp3"
,
1
,
16000
),
(
"mp3"
,
1
,
44100
),
(
"mp3"
,
2
,
8000
),
(
"mp3"
,
2
,
16000
),
(
"mp3"
,
2
,
44100
),
(
"opus"
,
1
,
48000
),
]
)
def
test_audio_num_frames_lossy
(
self
,
ext
,
num_channels
,
sample_rate
):
"""Saving audio preserves the number of channels and frames"""
filename
=
f
"test.
{
ext
}
"
data
=
get_sinusoid
(
sample_rate
=
sample_rate
,
n_channels
=
num_channels
,
channels_first
=
False
)
# Write data
dst
=
self
.
get_temp_path
(
filename
)
s
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
s
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
)
with
s
.
open
():
with
s
.
open
():
s
.
write_audio_chunk
(
0
,
chunk
)
s
.
write_audio_chunk
(
0
,
data
)
if
self
.
test_fileobj
:
dst
.
flush
()
# Load data
# Load data
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
s
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
...
@@ -324,9 +382,28 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -324,9 +382,28 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s
.
process_all_packets
()
s
.
process_all_packets
()
(
saved
,)
=
s
.
pop_chunks
()
(
saved
,)
=
s
.
pop_chunks
()
assert
saved
.
shape
==
chunk
.
shape
# On 4.1 OPUS produces 48312 samples (extra 312)
if
format
in
[
"wav"
,
"flac"
]:
# this has been fixed on 4.2+
self
.
assertEqual
(
saved
,
chunk
)
# TODO: issue warning if on 4.1?
if
ext
==
"opus"
and
lt42
():
return
self
.
assertEqual
(
saved
.
shape
,
data
.
shape
)
def test_g722_sample_rate(self):
    """Encoding G.722 properly converts sample rate to 16k"""
    filename = "test.g722"
    sample_rate = 41000
    data = get_sinusoid(sample_rate=sample_rate, n_channels=1, channels_first=False)

    # write data
    dst = self.get_temp_path(filename)
    writer = StreamWriter(dst, format="g722")
    writer.add_audio_stream(sample_rate=sample_rate, num_channels=1)
    with writer.open():
        writer.write_audio_chunk(0, data)

    # G.722 only supports 16 kHz, so the written stream must report that rate.
    reader = StreamReader(src=self.get_temp_path(filename))
    self.assertEqual(reader.get_src_stream_info(0).sample_rate, 16000)
def
test_preserve_fps
(
self
):
def
test_preserve_fps
(
self
):
"""Decimal point frame rate is properly saved
"""Decimal point frame rate is properly saved
...
@@ -339,16 +416,346 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
...
@@ -339,16 +416,346 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
width
,
height
=
96
,
128
width
,
height
=
96
,
128
# Write data
# Write data
dst
=
self
.
get_
dst
(
filename
)
dst
=
self
.
get_
temp_path
(
filename
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
writer
.
add_video_stream
(
frame_rate
=
frame_rate
,
width
=
width
,
height
=
height
)
writer
.
add_video_stream
(
frame_rate
=
frame_rate
,
width
=
width
,
height
=
height
)
video
=
torch
.
randint
(
256
,
(
90
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
video
=
torch
.
randint
(
256
,
(
90
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
with
writer
.
open
():
with
writer
.
open
():
writer
.
write_video_chunk
(
0
,
video
)
writer
.
write_video_chunk
(
0
,
video
)
if
self
.
test_fileobj
:
dst
.
flush
()
# Load data
# Load data
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
assert
reader
.
get_src_stream_info
(
0
).
frame_rate
==
frame_rate
assert
reader
.
get_src_stream_info
(
0
).
frame_rate
==
frame_rate
def
test_video_pts_increment
(
self
):
"""PTS values increment by the inverse of frame rate"""
ext
=
"mp4"
num_frames
=
256
filename
=
f
"test.
{
ext
}
"
frame_rate
=
5000
/
167
width
,
height
=
96
,
128
# Write data
dst
=
self
.
get_temp_path
(
filename
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
writer
.
add_video_stream
(
frame_rate
=
frame_rate
,
width
=
width
,
height
=
height
)
video
=
torch
.
randint
(
256
,
(
num_frames
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
with
writer
.
open
():
writer
.
write_video_chunk
(
0
,
video
)
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
reader
.
add_video_stream
(
1
)
pts
=
[
chunk
.
pts
for
(
chunk
,)
in
reader
.
stream
()]
assert
len
(
pts
)
==
num_frames
for
i
,
val
in
enumerate
(
pts
):
expected
=
i
/
frame_rate
assert
abs
(
val
-
expected
)
<
1e-10
def
test_audio_pts_increment
(
self
):
"""PTS values increment by the inverse of sample rate"""
ext
=
"wav"
filename
=
f
"test.
{
ext
}
"
sample_rate
=
8000
num_channels
=
2
# Write data
dst
=
self
.
get_temp_path
(
filename
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
writer
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
)
audio
=
get_sinusoid
(
sample_rate
=
sample_rate
,
n_channels
=
num_channels
,
channels_first
=
False
)
num_frames
=
audio
.
size
(
0
)
with
writer
.
open
():
writer
.
write_audio_chunk
(
0
,
audio
)
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
frames_per_chunk
=
sample_rate
//
4
reader
.
add_audio_stream
(
frames_per_chunk
,
-
1
)
chunks
=
[
chunk
for
(
chunk
,)
in
reader
.
stream
()]
expected
=
num_frames
//
(
frames_per_chunk
)
assert
len
(
chunks
)
==
expected
,
f
"Expected
{
expected
}
elements. Found
{
len
(
chunks
)
}
"
num_samples
=
0
for
chunk
in
chunks
:
expected
=
num_samples
/
sample_rate
num_samples
+=
chunk
.
size
(
0
)
print
(
chunk
.
pts
,
expected
)
assert
abs
(
chunk
.
pts
-
expected
)
<
1e-10
@
parameterized
.
expand
(
[
(
10
,
100
),
(
15
,
150
),
(
24
,
240
),
(
25
,
200
),
(
30
,
300
),
(
50
,
500
),
(
60
,
600
),
# PTS value conversion involves float <-> int conversion, which can
# introduce rounding error.
# This test is a spot-check for popular 29.97 Hz
(
30000
/
1001
,
10010
),
]
)
def
test_video_pts_overwrite
(
self
,
frame_rate
,
num_frames
):
"""Can overwrite PTS"""
ext
=
"mp4"
filename
=
f
"test.
{
ext
}
"
width
,
height
=
8
,
8
# Write data
dst
=
self
.
get_temp_path
(
filename
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
writer
.
add_video_stream
(
frame_rate
=
frame_rate
,
width
=
width
,
height
=
height
)
video
=
torch
.
zeros
((
1
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
reference_pts
=
[]
with
writer
.
open
():
for
i
in
range
(
num_frames
):
pts
=
i
/
frame_rate
reference_pts
.
append
(
pts
)
writer
.
write_video_chunk
(
0
,
video
,
pts
)
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
reader
.
add_video_stream
(
1
)
pts
=
[
chunk
.
pts
for
(
chunk
,)
in
reader
.
stream
()]
assert
len
(
pts
)
==
len
(
reference_pts
)
for
val
,
ref
in
zip
(
pts
,
reference_pts
):
# torch provides isclose, but we don't know if converting floats to tensor
# could introduce a descrepancy, so we compare floats and use math.isclose
# for that.
assert
math
.
isclose
(
val
,
ref
)
def
test_codec_config
(
self
):
"""Can successfully set configuration and write audio."""
ext
=
"mp3"
filename
=
f
"test.
{
ext
}
"
sample_rate
=
44100
num_channels
=
2
# Write data
dst
=
self
.
get_temp_path
(
filename
)
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
dst
,
format
=
ext
)
codec_config
=
CodecConfig
(
bit_rate
=
198_000
,
compression_level
=
3
)
writer
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
codec_config
=
codec_config
)
audio
=
torch
.
zeros
((
8000
,
2
))
with
writer
.
open
():
writer
.
write_audio_chunk
(
0
,
audio
)
def
test_codec_config_bit_rate_output
(
self
):
"""Increasing the specified bit rate yields a larger encoded output."""
ext
=
"mp3"
sample_rate
=
44100
num_channels
=
2
audio
=
torch
.
rand
((
8000
,
num_channels
))
def
write_audio
(
buffer
,
bit_rate
):
writer
=
torchaudio
.
io
.
StreamWriter
(
dst
=
buffer
,
format
=
ext
)
writer
.
add_audio_stream
(
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
codec_config
=
CodecConfig
(
bit_rate
=
bit_rate
),
)
with
writer
.
open
():
writer
.
write_audio_chunk
(
0
,
audio
)
dst
=
io
.
BytesIO
()
write_audio
(
dst
,
198_000
)
out0_size
=
dst
.
tell
()
dst
=
io
.
BytesIO
()
write_audio
(
dst
,
320_000
)
out1_size
=
dst
.
tell
()
self
.
assertGreater
(
out1_size
,
out0_size
)
def
test_filter_graph_audio
(
self
):
"""Can apply additional effect with filter graph"""
sample_rate
=
8000
num_channels
=
2
ext
=
"wav"
filename
=
f
"test.
{
ext
}
"
original
=
get_audio_chunk
(
"s16"
,
num_channels
=
num_channels
,
sample_rate
=
sample_rate
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_audio_stream
(
sample_rate
=
8000
,
num_channels
=
num_channels
,
filter_desc
=
"areverse"
,
format
=
"s16"
)
with
w
.
open
():
w
.
write_audio_chunk
(
0
,
original
)
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
reader
.
add_audio_stream
(
-
1
)
reader
.
process_all_packets
()
(
output
,)
=
reader
.
pop_chunks
()
self
.
assertEqual
(
output
,
original
.
flip
(
0
))
def
test_filter_graph_video
(
self
):
"""Can apply additional effect with filter graph"""
src_rate
=
30
num_frames
,
width
,
height
=
400
,
160
,
90
filter_desc
=
"framestep=2"
enc_rate
=
15
ext
=
"mp4"
filename
=
f
"test.
{
ext
}
"
original
=
torch
.
zeros
((
num_frames
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_video_stream
(
frame_rate
=
src_rate
,
format
=
"rgb24"
,
height
=
height
,
width
=
width
,
filter_desc
=
filter_desc
,
encoder_format
=
"yuv420p"
,
encoder_frame_rate
=
enc_rate
,
)
with
w
.
open
():
w
.
write_video_chunk
(
0
,
original
)
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
reader
.
add_video_stream
(
-
1
)
reader
.
process_all_packets
()
(
output
,)
=
reader
.
pop_chunks
()
self
.
assertEqual
(
output
.
shape
,
[
num_frames
//
2
,
3
,
height
,
width
])
@
parameterized
.
expand
(
[
(
"wav"
,
"pcm_s16le"
,
8000
,
16000
,
1
,
2
),
(
"wav"
,
"pcm_s16le"
,
8000
,
16000
,
2
,
1
),
(
"wav"
,
"pcm_s16le"
,
8000
,
16000
,
2
,
4
),
(
"wav"
,
"pcm_s16le"
,
16000
,
8000
,
1
,
2
),
(
"wav"
,
"pcm_s16le"
,
16000
,
8000
,
2
,
1
),
(
"wav"
,
"pcm_s16le"
,
16000
,
8000
,
2
,
4
),
(
"wav"
,
"pcm_f32le"
,
8000
,
16000
,
1
,
2
),
(
"wav"
,
"pcm_f32le"
,
8000
,
16000
,
2
,
1
),
(
"wav"
,
"pcm_f32le"
,
8000
,
16000
,
2
,
4
),
(
"wav"
,
"pcm_f32le"
,
16000
,
8000
,
1
,
2
),
(
"wav"
,
"pcm_f32le"
,
16000
,
8000
,
2
,
1
),
(
"wav"
,
"pcm_f32le"
,
16000
,
8000
,
2
,
4
),
(
"ogg"
,
"opus"
,
8000
,
48000
,
1
,
2
),
(
"ogg"
,
"opus"
,
8000
,
48000
,
2
,
1
),
(
"ogg"
,
"flac"
,
8000
,
41000
,
1
,
2
),
(
"ogg"
,
"flac"
,
8000
,
41000
,
2
,
1
),
(
"ogg"
,
"vorbis"
,
16000
,
8000
,
1
,
2
),
(
"ogg"
,
"vorbis"
,
16000
,
8000
,
4
,
2
),
]
)
def
test_change_audio_encoder_spec
(
self
,
ext
,
encoder
,
src_sr
,
enc_sr
,
src_num_channels
,
enc_num_channels
):
"""Can change sample rate and channels on-the-fly"""
filename
=
f
"test.
{
ext
}
"
original
=
get_sinusoid
(
sample_rate
=
src_sr
,
n_channels
=
src_num_channels
,
channels_first
=
False
,
duration
=
0.1
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_audio_stream
(
sample_rate
=
src_sr
,
format
=
"flt"
,
num_channels
=
src_num_channels
,
encoder
=
encoder
,
encoder_sample_rate
=
enc_sr
,
encoder_num_channels
=
enc_num_channels
,
)
with
w
.
open
():
w
.
write_audio_chunk
(
0
,
original
)
# check
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
i
=
reader
.
get_src_stream_info
(
0
)
self
.
assertEqual
(
i
.
sample_rate
,
enc_sr
)
self
.
assertEqual
(
i
.
num_channels
,
enc_num_channels
)
@
parameterized
.
expand
(
[
# opus only supports 48kHz
(
"ogg"
,
"opus"
,
8000
,
48000
,
1
,
1
),
(
"ogg"
,
"opus"
,
16000
,
48000
,
2
,
2
),
# vorbis only supports 2 channels
(
"ogg"
,
"vorbis"
,
16000
,
16000
,
1
,
2
),
(
"ogg"
,
"vorbis"
,
16000
,
16000
,
2
,
2
),
(
"ogg"
,
"vorbis"
,
16000
,
16000
,
4
,
2
),
]
)
def
test_change_encoder_spec_default
(
self
,
ext
,
encoder
,
src_sr
,
expected_sr
,
src_num_channels
,
expected_num_channels
):
"""If input rate/channels are not supported, encoder picks supported one automatically."""
filename
=
f
"test.
{
ext
}
"
original
=
get_sinusoid
(
sample_rate
=
src_sr
,
n_channels
=
src_num_channels
,
channels_first
=
False
,
duration
=
0.1
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_audio_stream
(
sample_rate
=
src_sr
,
format
=
"flt"
,
num_channels
=
src_num_channels
,
encoder
=
encoder
,
)
with
w
.
open
():
w
.
write_audio_chunk
(
0
,
original
)
# check
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
i
=
reader
.
get_src_stream_info
(
0
)
self
.
assertEqual
(
i
.
sample_rate
,
expected_sr
)
self
.
assertEqual
(
i
.
num_channels
,
expected_num_channels
)
@
parameterized
.
expand
(
[
(
"mp4"
,
None
,
10
,
30
,
(
100
,
160
),
(
200
,
320
)),
(
"mp4"
,
None
,
10
,
30
,
(
100
,
160
),
(
50
,
80
)),
(
"mp4"
,
None
,
30
,
10
,
(
100
,
160
),
(
200
,
320
)),
(
"mp4"
,
None
,
30
,
10
,
(
100
,
160
),
(
50
,
80
)),
]
)
def
test_change_video_encoder_spec
(
self
,
ext
,
encoder
,
src_rate
,
enc_rate
,
src_size
,
enc_size
):
"""Can change the frame rate and image size on-the-fly"""
width
,
height
=
src_size
enc_width
,
enc_height
=
enc_size
ext
=
"mp4"
filename
=
f
"test.
{
ext
}
"
num_frames
=
256
original
=
torch
.
zeros
((
num_frames
,
3
,
height
,
width
),
dtype
=
torch
.
uint8
)
dst
=
self
.
get_temp_path
(
filename
)
w
=
StreamWriter
(
dst
,
format
=
ext
)
w
.
add_video_stream
(
frame_rate
=
src_rate
,
format
=
"rgb24"
,
height
=
height
,
width
=
width
,
encoder_format
=
"yuv420p"
,
encoder_frame_rate
=
enc_rate
,
encoder_width
=
enc_width
,
encoder_height
=
enc_height
,
)
with
w
.
open
():
w
.
write_video_chunk
(
0
,
original
)
# check
reader
=
torchaudio
.
io
.
StreamReader
(
src
=
self
.
get_temp_path
(
filename
))
i
=
reader
.
get_src_stream_info
(
0
)
self
.
assertEqual
(
i
.
frame_rate
,
enc_rate
)
self
.
assertEqual
(
i
.
width
,
enc_width
)
self
.
assertEqual
(
i
.
height
,
enc_height
)
test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
View file @
ffeba11a
...
@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
...
@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
expected_tokens
=
[
"|"
,
"f"
,
"|"
,
"o"
,
"a"
]
expected_tokens
=
[
"|"
,
"f"
,
"|"
,
"o"
,
"a"
]
self
.
assertEqual
(
tokens
,
expected_tokens
)
self
.
assertEqual
(
tokens
,
expected_tokens
)
def
test_lm_lifecycle
(
self
):
"""Passing lm without assiging it to a vaiable won't cause runtime error
https://github.com/pytorch/audio/issues/3218
"""
from
torchaudio.models.decoder
import
ctc_decoder
from
.ctc_decoder_utils
import
CustomZeroLM
decoder
=
ctc_decoder
(
lexicon
=
get_asset_path
(
"decoder/lexicon.txt"
),
tokens
=
get_asset_path
(
"decoder/tokens.txt"
),
lm
=
CustomZeroLM
(),
)
decoder
(
torch
.
zeros
((
1
,
3
,
NUM_TOKENS
),
dtype
=
torch
.
float32
))
test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py
0 → 100644
View file @
ffeba11a
import
torch
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
skipIfNoCuCtcDecoder
,
skipIfNoCuda
,
TempDirMixin
,
TorchaudioTestCase
,
)
NUM_TOKENS
=
7
@skipIfNoCuda
@skipIfNoCuCtcDecoder
class CUCTCDecoderTest(TempDirMixin, TorchaudioTestCase):
    """Smoke tests for the CUDA-based CTC beam-search decoder."""

    def _get_decoder(self, tokens=None, **kwargs):
        """Build a cuda_ctc_decoder, defaulting to the bundled token file."""
        from torchaudio.models.decoder import cuda_ctc_decoder

        if tokens is None:
            tokens = get_asset_path("decoder/tokens.txt")
        return cuda_ctc_decoder(tokens=tokens, beam_size=5, **kwargs)

    def _get_emissions(self):
        """Random log-probability emissions of shape (batch, time, tokens) on CUDA."""
        B, T, N = 4, 15, NUM_TOKENS
        emissions = torch.rand(B, T, N).cuda()
        return torch.nn.functional.log_softmax(emissions, -1)

    def test_construct_basic_decoder_path(self):
        # Construct from a token file path.
        self._get_decoder(tokens=get_asset_path("decoder/tokens.txt"))

    def test_construct_basic_decoder_tokens(self):
        # Construct from an explicit token list.
        self._get_decoder(tokens=["-", "|", "f", "o", "b", "a", "r"])

    def test_shape(self):
        # The decoder returns one result per batch element.
        log_probs = self._get_emissions()
        encoder_out_lens = torch.tensor([15, 14, 13, 12], dtype=torch.int32).cuda()
        decoder = self._get_decoder()
        results = decoder(log_probs, encoder_out_lens)
        self.assertEqual(len(results), log_probs.shape[0])
test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
View file @
ffeba11a
...
@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
...
@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
self
.
assertEqual
(
res
,
scripted_res
)
self
.
assertEqual
(
res
,
scripted_res
)
state
=
res
[
1
]
state
=
res
[
1
]
hypo
=
res
[
0
]
[
0
]
hypo
=
res
[
0
]
scripted_state
=
scripted_res
[
1
]
scripted_state
=
scripted_res
[
1
]
scripted_hypo
=
scripted_res
[
0
]
[
0
]
scripted_hypo
=
scripted_res
[
0
]
test/torchaudio_unittest/models/squim/__init__.py
0 → 100644
View file @
ffeba11a
test/torchaudio_unittest/models/squim/squim_test.py
0 → 100644
View file @
ffeba11a
import
torch
from
parameterized
import
parameterized
from
torchaudio.models
import
squim_objective_base
,
squim_subjective_base
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
torch_script
,
TorchaudioTestCase
class TestSquimObjective(TorchaudioTestCase):
    """Tests for the SQUIM objective-metric model."""

    def _smoke_test_objective(self, model, device, dtype):
        """Run a forward pass on the given device/dtype without checking values."""
        model = model.to(device=device, dtype=dtype)
        model = model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        model(waveforms)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        self._smoke_test_objective(squim_objective_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        self._smoke_test_objective(squim_objective_base(), torch.device("cuda"), dtype)

    def test_batch_consistency(self):
        """Scoring one item at a time matches scoring the whole batch."""
        model = squim_objective_base()
        model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)

        ref_scores = model(waveforms)
        # The model returns three score tensors; fill them one sample at a time.
        hyp_scores = [torch.zeros(batch_size), torch.zeros(batch_size), torch.zeros(batch_size)]
        for i in range(batch_size):
            scores = model(waveforms[i : i + 1])
            for j in range(3):
                hyp_scores[j][i] = scores[j]

        self.assertEqual(len(hyp_scores), len(ref_scores))
        for i in range(len(ref_scores)):
            self.assertEqual(hyp_scores[i], ref_scores[i])

    def test_torchscript_consistency(self):
        """The scripted model produces the same scores as eager mode."""
        model = squim_objective_base()
        model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)

        ref_scores = model(waveforms)
        hyp_scores = torch_script(model)(waveforms)

        self.assertEqual(len(hyp_scores), len(ref_scores))
        for i in range(len(ref_scores)):
            self.assertEqual(hyp_scores[i], ref_scores[i])
class TestSquimSubjective(TorchaudioTestCase):
    """Tests for the SQUIM subjective-metric model (takes a reference signal)."""

    def _smoke_test_subjective(self, model, device, dtype):
        """Run a forward pass on the given device/dtype without checking values."""
        model = model.to(device=device, dtype=dtype)
        model = model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        reference = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
        model(waveforms, reference)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    def test_cpu_smoke_test(self, dtype):
        self._smoke_test_subjective(squim_subjective_base(), torch.device("cpu"), dtype)

    @parameterized.expand([(torch.float32,), (torch.float64,)])
    @skipIfNoCuda
    def test_cuda_smoke_test(self, dtype):
        self._smoke_test_subjective(squim_subjective_base(), torch.device("cuda"), dtype)

    def test_batch_consistency(self):
        """Scoring one pair at a time matches scoring the whole batch."""
        model = squim_subjective_base()
        model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)
        reference = torch.randn(batch_size, num_frames)

        ref_scores = model(waveforms, reference)
        hyp_scores = []
        for i in range(batch_size):
            scores = model(waveforms[i : i + 1], reference[i : i + 1])
            hyp_scores.append(scores)
        hyp_scores = torch.tensor(hyp_scores)

        self.assertEqual(hyp_scores, ref_scores)

    def test_torchscript_consistency(self):
        """The scripted model produces the same scores as eager mode."""
        model = squim_subjective_base()
        model.eval()
        batch_size, num_frames = 3, 16000
        waveforms = torch.randn(batch_size, num_frames)
        reference = torch.randn(batch_size, num_frames)

        ref_scores = model(waveforms, reference)
        hyp_scores = torch_script(model)(waveforms, reference)

        self.assertEqual(hyp_scores, ref_scores)
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
View file @
ffeba11a
...
@@ -42,7 +42,7 @@ class TorchscriptConsistencyMixin(TestBaseMixin):
...
@@ -42,7 +42,7 @@ class TorchscriptConsistencyMixin(TestBaseMixin):
class
Tacotron2EncoderTests
(
TorchscriptConsistencyMixin
):
class
Tacotron2EncoderTests
(
TorchscriptConsistencyMixin
):
@
skipIfPy310
#
@skipIfPy310
def
test_tacotron2_torchscript_consistency
(
self
):
def
test_tacotron2_torchscript_consistency
(
self
):
r
"""Validate the torchscript consistency of a Encoder."""
r
"""Validate the torchscript consistency of a Encoder."""
n_batch
,
n_seq
,
encoder_embedding_dim
=
16
,
64
,
512
n_batch
,
n_seq
,
encoder_embedding_dim
=
16
,
64
,
512
...
@@ -266,7 +266,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
...
@@ -266,7 +266,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
(
16
,),
(
16
,),
]
]
)
)
@
skipIfPy310
#
@skipIfPy310
def
test_tacotron2_torchscript_consistency
(
self
,
n_batch
):
def
test_tacotron2_torchscript_consistency
(
self
,
n_batch
):
r
"""Validate the torchscript consistency of a Tacotron2."""
r
"""Validate the torchscript consistency of a Tacotron2."""
n_mels
=
80
n_mels
=
80
...
@@ -335,7 +335,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
...
@@ -335,7 +335,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
(
16
,),
(
16
,),
]
]
)
)
@
skipIfPy310
#
@skipIfPy310
def
test_tacotron2_inference_torchscript_consistency
(
self
,
n_batch
):
def
test_tacotron2_inference_torchscript_consistency
(
self
,
n_batch
):
r
"""Validate the torchscript consistency of Tacotron2 inference function."""
r
"""Validate the torchscript consistency of Tacotron2 inference function."""
n_mels
=
40
n_mels
=
40
...
...
test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py
View file @
ffeba11a
...
@@ -9,9 +9,12 @@ from torchaudio.models.wav2vec2 import (
...
@@ -9,9 +9,12 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base
,
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large
,
wav2vec2_large_lv60k
,
wav2vec2_large_lv60k
,
wav2vec2_xlsr_1b
,
wav2vec2_xlsr_2b
,
wav2vec2_xlsr_300m
,
)
)
from
torchaudio.models.wav2vec2.utils
import
import_fairseq_model
from
torchaudio.models.wav2vec2.utils
import
import_fairseq_model
from
torchaudio_unittest.common_utils
import
get_asset_path
,
skipIfNoModule
,
TorchaudioTestCase
from
torchaudio_unittest.common_utils
import
get_asset_path
,
skipIfCudaSmallMemory
,
skipIfNoModule
,
TorchaudioTestCase
def
_load_config
(
*
paths
):
def
_load_config
(
*
paths
):
...
@@ -31,6 +34,9 @@ WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
...
@@ -31,6 +34,9 @@ WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
HUBERT_BASE
=
_load_config
(
"hubert_base_ls960"
)
HUBERT_BASE
=
_load_config
(
"hubert_base_ls960"
)
HUBERT_LARGE_LL60K
=
_load_config
(
"hubert_large_ll60k"
)
HUBERT_LARGE_LL60K
=
_load_config
(
"hubert_large_ll60k"
)
HUBERT_XLARGE_LL60K
=
_load_config
(
"hubert_xtralarge_ll60k"
)
HUBERT_XLARGE_LL60K
=
_load_config
(
"hubert_xtralarge_ll60k"
)
WAV2VEC2_XLSR_300M
=
_load_config
(
"xlsr_300m"
)
WAV2VEC2_XLSR_1B
=
_load_config
(
"xlsr_1b"
)
WAV2VEC2_XLSR_2B
=
_load_config
(
"xlsr_2b"
)
# Finetuning models
# Finetuning models
WAV2VEC2_BASE_960H
=
_load_config
(
"wav2vec_small_960h"
)
WAV2VEC2_BASE_960H
=
_load_config
(
"wav2vec_small_960h"
)
WAV2VEC2_LARGE_960H
=
_load_config
(
"wav2vec_large_960h"
)
WAV2VEC2_LARGE_960H
=
_load_config
(
"wav2vec_large_960h"
)
...
@@ -50,6 +56,14 @@ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
...
@@ -50,6 +56,14 @@ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
],
],
name_func
=
_name_func
,
name_func
=
_name_func
,
)
)
XLSR_PRETRAINING_CONFIGS
=
parameterized
.
expand
(
[
(
WAV2VEC2_XLSR_300M
,
wav2vec2_xlsr_300m
),
(
WAV2VEC2_XLSR_1B
,
wav2vec2_xlsr_1b
),
(
WAV2VEC2_XLSR_2B
,
wav2vec2_xlsr_2b
),
],
name_func
=
_name_func
,
)
HUBERT_PRETRAINING_CONFIGS
=
parameterized
.
expand
(
HUBERT_PRETRAINING_CONFIGS
=
parameterized
.
expand
(
[
[
(
HUBERT_BASE
,
hubert_base
),
(
HUBERT_BASE
,
hubert_base
),
...
@@ -134,7 +148,24 @@ class TestFairseqIntegration(TorchaudioTestCase):
...
@@ -134,7 +148,24 @@ class TestFairseqIntegration(TorchaudioTestCase):
hyp
,
_
=
imported
.
extract_features
(
x
)
hyp
,
_
=
imported
.
extract_features
(
x
)
refs
=
original
.
extract_features
(
x
,
padding_mask
=
torch
.
zeros_like
(
x
),
layer
=-
1
)
refs
=
original
.
extract_features
(
x
,
padding_mask
=
torch
.
zeros_like
(
x
),
layer
=-
1
)
for
i
,
(
ref
,
_
)
in
enumerate
(
refs
[
"layer_results"
]):
for
i
,
(
ref
,
_
)
in
enumerate
(
refs
[
"layer_results"
]):
self
.
assertEqual
(
hyp
[
i
],
ref
.
transpose
(
0
,
1
))
self
.
assertEqual
(
hyp
[
i
],
ref
.
transpose
(
0
,
1
),
atol
=
1.5e-5
,
rtol
=
1.3e-6
)
@
XLSR_PRETRAINING_CONFIGS
@
skipIfCudaSmallMemory
def
test_import_xlsr_pretraining_model
(
self
,
config
,
factory_func
):
"""XLS-R pretraining models from fairseq can be imported and yields the same results"""
batch_size
,
num_frames
=
3
,
1024
original
=
self
.
_get_model
(
config
).
eval
()
imported
=
import_fairseq_model
(
original
).
eval
()
x
=
torch
.
randn
(
batch_size
,
num_frames
)
hyp
,
_
=
imported
.
extract_features
(
x
)
refs
=
original
.
extract_features
(
x
,
padding_mask
=
torch
.
zeros_like
(
x
),
layer
=-
1
)
for
i
,
(
ref
,
_
)
in
enumerate
(
refs
[
"layer_results"
]):
# There is one element whose difference is over 1e-5 in wav2vec2_xlsr_1b and wav2vec2_xlsr_2b.
atol
=
1.0e-05
if
factory_func
is
wav2vec2_xlsr_300m
else
1e-4
self
.
assertEqual
(
hyp
[
i
],
ref
.
transpose
(
0
,
1
),
atol
=
atol
,
rtol
=
1.3e-6
)
@
HUBERT_PRETRAINING_CONFIGS
@
HUBERT_PRETRAINING_CONFIGS
def
test_import_hubert_pretraining_model
(
self
,
config
,
factory_func
):
def
test_import_hubert_pretraining_model
(
self
,
config
,
factory_func
):
...
@@ -150,15 +181,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
...
@@ -150,15 +181,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
# check the last layer
# check the last layer
ref
,
_
=
original
.
extract_features
(
x
,
padding_mask
=
mask
,
output_layer
=
len
(
original
.
encoder
.
layers
))
ref
,
_
=
original
.
extract_features
(
x
,
padding_mask
=
mask
,
output_layer
=
len
(
original
.
encoder
.
layers
))
atol
=
3.0e-05
if
factory_func
is
hubert_xlarge
else
1.0e-5
self
.
assertEqual
(
hyp
[
-
1
],
ref
,
atol
=
3.0e-5
,
rtol
=
1.3e-6
)
self
.
assertEqual
(
hyp
[
-
1
],
ref
,
atol
=
atol
,
rtol
=
1.3e-6
)
# check the first layer
# check the first layer
ref
,
_
=
original
.
extract_features
(
x
,
padding_mask
=
mask
,
output_layer
=
1
)
ref
,
_
=
original
.
extract_features
(
x
,
padding_mask
=
mask
,
output_layer
=
1
)
self
.
assertEqual
(
hyp
[
0
],
ref
)
self
.
assertEqual
(
hyp
[
0
],
ref
)
@
ALL_PRETRAINING_CONFIGS
def
_test_recreate_pretraining_model
(
self
,
config
,
factory_func
):
def
test_recreate_pretraining_model
(
self
,
config
,
factory_func
):
"""Imported pretraining models can be recreated via a factory function without fairseq."""
"""Imported pretraining models can be recreated via a factory function without fairseq."""
batch_size
,
num_frames
=
3
,
1024
batch_size
,
num_frames
=
3
,
1024
...
@@ -188,6 +217,15 @@ class TestFairseqIntegration(TorchaudioTestCase):
...
@@ -188,6 +217,15 @@ class TestFairseqIntegration(TorchaudioTestCase):
self
.
assertEqual
(
ref
,
hyp
)
self
.
assertEqual
(
ref
,
hyp
)
self
.
assertEqual
(
ref_lengths
,
hyp_lengths
)
self
.
assertEqual
(
ref_lengths
,
hyp_lengths
)
@
ALL_PRETRAINING_CONFIGS
def
test_wav2vec2_recreate_pretraining_model
(
self
,
config
,
factory_func
):
self
.
_test_recreate_pretraining_model
(
config
,
factory_func
)
@
XLSR_PRETRAINING_CONFIGS
@
skipIfCudaSmallMemory
def
test_xlsr_recreate_pretraining_model
(
self
,
config
,
factory_func
):
self
.
_test_recreate_pretraining_model
(
config
,
factory_func
)
@
FINETUNING_CONFIGS
@
FINETUNING_CONFIGS
def
test_import_finetuning_model
(
self
,
config
,
_
):
def
test_import_finetuning_model
(
self
,
config
,
_
):
"""Fintuned wav2vec2 models from fairseq can be imported and yields the same results"""
"""Fintuned wav2vec2 models from fairseq can be imported and yields the same results"""
...
...
test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py
View file @
ffeba11a
import
json
import
json
import
unittest
import
torch
import
torch
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio.models.wav2vec2
import
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large_lv60k
from
torchaudio.models.wav2vec2
import
(
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large_lv60k
,
wav2vec2_xlsr_1b
,
wav2vec2_xlsr_2b
,
wav2vec2_xlsr_300m
,
wavlm_base
,
wavlm_large
,
)
from
torchaudio.models.wav2vec2.utils
import
import_huggingface_model
from
torchaudio.models.wav2vec2.utils
import
import_huggingface_model
from
torchaudio_unittest.common_utils
import
get_asset_path
,
skipIfNoModule
,
TorchaudioTestCase
from
torchaudio_unittest.common_utils
import
(
get_asset_path
,
skipIfCudaSmallMemory
,
skipIfNoModule
,
TorchaudioTestCase
,
zip_equal
,
)
def
_load_config
(
*
paths
):
def
_load_config
(
*
paths
):
...
@@ -22,6 +38,11 @@ HF_LARGE = _load_config("wav2vec2-large")
...
@@ -22,6 +38,11 @@ HF_LARGE = _load_config("wav2vec2-large")
HF_LARGE_LV60
=
_load_config
(
"wav2vec2-large-lv60"
)
HF_LARGE_LV60
=
_load_config
(
"wav2vec2-large-lv60"
)
HF_LARGE_XLSR_53
=
_load_config
(
"wav2vec2-large-xlsr-53"
)
HF_LARGE_XLSR_53
=
_load_config
(
"wav2vec2-large-xlsr-53"
)
HF_BASE_10K_VOXPOPULI
=
_load_config
(
"wav2vec2-base-10k-voxpopuli"
)
HF_BASE_10K_VOXPOPULI
=
_load_config
(
"wav2vec2-base-10k-voxpopuli"
)
HF_BASE_WAVLM
=
_load_config
(
"wavlm-base"
)
HF_LARGE_WAVLM
=
_load_config
(
"wavlm-large"
)
HF_XLSR_300M
=
_load_config
(
"wav2vec2-xls-r-300m"
)
HF_XLSR_1B
=
_load_config
(
"wav2vec2-xls-r-1b"
)
HF_XLSR_2B
=
_load_config
(
"wav2vec2-xls-r-2b"
)
# Finetuned
# Finetuned
HF_BASE_960H
=
_load_config
(
"wav2vec2-base-960h"
)
HF_BASE_960H
=
_load_config
(
"wav2vec2-base-960h"
)
HF_LARGE_960H
=
_load_config
(
"wav2vec2-large-960h"
)
HF_LARGE_960H
=
_load_config
(
"wav2vec2-large-960h"
)
...
@@ -40,6 +61,14 @@ PRETRAIN_CONFIGS = parameterized.expand(
...
@@ -40,6 +61,14 @@ PRETRAIN_CONFIGS = parameterized.expand(
],
],
name_func
=
_name_func
,
name_func
=
_name_func
,
)
)
XLSR_PRETRAIN_CONFIGS
=
parameterized
.
expand
(
[
(
HF_XLSR_300M
,
wav2vec2_xlsr_300m
),
(
HF_XLSR_1B
,
wav2vec2_xlsr_1b
),
(
HF_XLSR_2B
,
wav2vec2_xlsr_2b
),
],
name_func
=
_name_func
,
)
FINETUNE_CONFIGS
=
parameterized
.
expand
(
FINETUNE_CONFIGS
=
parameterized
.
expand
(
[
[
(
HF_BASE_960H
,
wav2vec2_base
),
(
HF_BASE_960H
,
wav2vec2_base
),
...
@@ -50,8 +79,16 @@ FINETUNE_CONFIGS = parameterized.expand(
...
@@ -50,8 +79,16 @@ FINETUNE_CONFIGS = parameterized.expand(
],
],
name_func
=
_name_func
,
name_func
=
_name_func
,
)
)
WAVLM_CONFIGS
=
parameterized
.
expand
(
[
(
HF_BASE_WAVLM
,
wavlm_base
),
(
HF_LARGE_WAVLM
,
wavlm_large
),
],
name_func
=
_name_func
,
)
@
unittest
.
skip
(
"transformers v4.30 seems to break the weight format. See https://github.com/pytorch/audio/issues/3430"
)
@
skipIfNoModule
(
"transformers"
)
@
skipIfNoModule
(
"transformers"
)
class
TestHFIntegration
(
TorchaudioTestCase
):
class
TestHFIntegration
(
TorchaudioTestCase
):
"""Test the process of importing the models from Hugging Face Transformers
"""Test the process of importing the models from Hugging Face Transformers
...
@@ -68,12 +105,14 @@ class TestHFIntegration(TorchaudioTestCase):
...
@@ -68,12 +105,14 @@ class TestHFIntegration(TorchaudioTestCase):
# However, somehow, once "transformers" is imported, `is_module_available`
# However, somehow, once "transformers" is imported, `is_module_available`
# starts to fail. Therefore, we defer importing "transformers" until
# starts to fail. Therefore, we defer importing "transformers" until
# the actual tests are started.
# the actual tests are started.
from
transformers
.models.wav2vec2
import
Wav2Vec2Config
,
Wav2Vec2ForCTC
,
Wav2Vec2Model
from
transformers
import
Wav2Vec2Config
,
Wav2Vec2ForCTC
,
Wav2Vec2Model
,
WavLMConfig
,
WavLMModel
if
config
[
"architectures"
]
==
[
"Wav2Vec2Model"
]:
if
config
[
"architectures"
]
==
[
"Wav2Vec2Model"
]:
return
Wav2Vec2Model
(
Wav2Vec2Config
(
**
config
))
return
Wav2Vec2Model
(
Wav2Vec2Config
(
**
config
))
if
config
[
"architectures"
]
==
[
"Wav2Vec2ForCTC"
]:
if
config
[
"architectures"
]
==
[
"Wav2Vec2ForCTC"
]:
return
Wav2Vec2ForCTC
(
Wav2Vec2Config
(
**
config
))
return
Wav2Vec2ForCTC
(
Wav2Vec2Config
(
**
config
))
if
config
[
"architectures"
]
==
[
"WavLMModel"
]:
return
WavLMModel
(
WavLMConfig
(
**
config
))
raise
ValueError
(
f
'Unexpected arch:
{
config
[
"architectures"
]
}
'
)
raise
ValueError
(
f
'Unexpected arch:
{
config
[
"architectures"
]
}
'
)
def
_test_import_pretrain
(
self
,
original
,
imported
,
config
):
def
_test_import_pretrain
(
self
,
original
,
imported
,
config
):
...
@@ -97,9 +136,8 @@ class TestHFIntegration(TorchaudioTestCase):
...
@@ -97,9 +136,8 @@ class TestHFIntegration(TorchaudioTestCase):
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
x
=
torch
.
randn
(
b
,
l
,
e
)
x
=
torch
.
randn
(
b
,
l
,
e
)
mask
=
torch
.
randn
(
b
,
1
,
l
,
l
)
mask
=
torch
.
randn
(
b
,
1
,
l
,
l
)
(
ref
,)
=
original_
(
x
,
attention_mask
=
mask
,
output_attentions
=
False
)
(
ref
,)
=
original_
(
x
,
attention_mask
=
mask
,
output_attentions
=
False
)
hyp
=
imported_
(
x
,
mask
)
hyp
,
_
=
imported_
(
x
,
mask
)
# Ignore returned position_bias, which is always None for Wav2Vec2 and HuBERT
self
.
assertEqual
(
ref
,
hyp
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole Encoder Transformer
# The whole Encoder Transformer
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
...
@@ -115,11 +153,6 @@ class TestHFIntegration(TorchaudioTestCase):
...
@@ -115,11 +153,6 @@ class TestHFIntegration(TorchaudioTestCase):
hyp
=
imported
.
aux
(
x
)
hyp
=
imported
.
aux
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole model without mask
# The whole model without mask
x
=
torch
.
randn
(
3
,
1024
)
ref
=
original
(
x
).
logits
hyp
,
_
=
imported
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole model without mask
batch_size
,
num_frames
=
3
,
1024
batch_size
,
num_frames
=
3
,
1024
x
=
torch
.
randn
(
batch_size
,
num_frames
)
x
=
torch
.
randn
(
batch_size
,
num_frames
)
ref
=
original
(
x
).
logits
ref
=
original
(
x
).
logits
...
@@ -151,6 +184,14 @@ class TestHFIntegration(TorchaudioTestCase):
...
@@ -151,6 +184,14 @@ class TestHFIntegration(TorchaudioTestCase):
imported
=
import_huggingface_model
(
original
).
eval
()
imported
=
import_huggingface_model
(
original
).
eval
()
self
.
_test_import_pretrain
(
original
,
imported
,
config
)
self
.
_test_import_pretrain
(
original
,
imported
,
config
)
@
XLSR_PRETRAIN_CONFIGS
@
skipIfCudaSmallMemory
def
test_import_xlsr_pretrain
(
self
,
config
,
_
):
"""XLS-R models from HF transformers can be imported and yields the same results"""
original
=
self
.
_get_model
(
config
).
eval
()
imported
=
import_huggingface_model
(
original
).
eval
()
self
.
_test_import_pretrain
(
original
,
imported
,
config
)
@
FINETUNE_CONFIGS
@
FINETUNE_CONFIGS
def
test_import_finetune
(
self
,
config
,
_
):
def
test_import_finetune
(
self
,
config
,
_
):
"""wav2vec2 models from HF transformers can be imported and yields the same results"""
"""wav2vec2 models from HF transformers can be imported and yields the same results"""
...
@@ -159,6 +200,51 @@ class TestHFIntegration(TorchaudioTestCase):
...
@@ -159,6 +200,51 @@ class TestHFIntegration(TorchaudioTestCase):
self
.
_test_import_pretrain
(
original
.
wav2vec2
,
imported
,
config
)
self
.
_test_import_pretrain
(
original
.
wav2vec2
,
imported
,
config
)
self
.
_test_import_finetune
(
original
,
imported
,
config
)
self
.
_test_import_finetune
(
original
,
imported
,
config
)
@
WAVLM_CONFIGS
def
test_import_pretrain_wavlm
(
self
,
config
,
_
):
"""WavLM models from HF transformers can be imported and yield the same results"""
original
=
self
.
_get_model
(
config
).
eval
()
imported
=
import_huggingface_model
(
original
).
eval
()
# FeatureExtractor
x
=
torch
.
randn
(
3
,
1024
)
ref
=
original
.
feature_extractor
(
x
).
transpose
(
1
,
2
)
hyp
,
_
=
imported
.
feature_extractor
(
x
,
None
)
self
.
assertEqual
(
ref
,
hyp
)
# Feature projection
x
=
torch
.
randn
(
3
,
10
,
config
[
"conv_dim"
][
-
1
])
ref
=
original
.
feature_projection
(
x
)[
0
]
hyp
=
imported
.
encoder
.
feature_projection
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# Convolutional Positional Encoder
x
=
torch
.
randn
(
3
,
256
,
config
[
"hidden_size"
])
ref
=
original
.
encoder
.
pos_conv_embed
(
x
)
hyp
=
imported
.
encoder
.
transformer
.
pos_conv_embed
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
position_bias
=
None
position_bias_imp
=
None
assert
len
(
original
.
encoder
.
layers
)
>
0
for
original_
,
imported_
in
zip_equal
(
original
.
encoder
.
layers
,
imported
.
encoder
.
transformer
.
layers
):
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
x
=
torch
.
randn
(
b
,
l
,
e
)
mask
=
torch
.
randn
(
b
,
l
)
>
0.5
# HF WaveLM model expects the mask to be binary
# HF WaveLM model (original_) takes in "attention mask" but actually uses it as key padding mask:
# https://github.com/huggingface/transformers/blob/b047472650cba259621549ac27b18fd2066ce18e/src/transformers/models/wavlm/modeling_wavlm.py#L495
ref
,
position_bias
=
original_
(
x
,
attention_mask
=
mask
,
output_attentions
=
False
,
position_bias
=
position_bias
)
hyp
,
position_bias_imp
=
imported_
(
x
,
key_padding_mask
=
mask
.
ne
(
1
),
position_bias
=
position_bias_imp
)
# Masked-out elements are undefined in the output
ref_filled
=
ref
.
masked_fill
(
~
mask
.
unsqueeze
(
2
),
0
)
hyp_filled
=
hyp
.
masked_fill
(
~
mask
.
unsqueeze
(
2
),
0
)
self
.
assertEqual
(
ref_filled
,
hyp_filled
)
# The whole Encoder Transformer
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
x
=
torch
.
randn
(
b
,
l
,
e
)
ref
=
original
.
encoder
(
x
).
last_hidden_state
hyp
=
imported
.
encoder
.
transformer
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
def
_test_recreate
(
self
,
imported
,
reloaded
,
config
):
def
_test_recreate
(
self
,
imported
,
reloaded
,
config
):
# FeatureExtractor
# FeatureExtractor
x
=
torch
.
randn
(
3
,
1024
)
x
=
torch
.
randn
(
3
,
1024
)
...
@@ -221,3 +307,50 @@ class TestHFIntegration(TorchaudioTestCase):
...
@@ -221,3 +307,50 @@ class TestHFIntegration(TorchaudioTestCase):
reloaded
.
load_state_dict
(
imported
.
state_dict
())
reloaded
.
load_state_dict
(
imported
.
state_dict
())
reloaded
.
eval
()
reloaded
.
eval
()
self
.
_test_recreate
(
imported
,
reloaded
,
config
)
self
.
_test_recreate
(
imported
,
reloaded
,
config
)
@
WAVLM_CONFIGS
def
test_recreate_wavlm
(
self
,
config
,
factory_func
):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported
=
import_huggingface_model
(
self
.
_get_model
(
config
)).
eval
()
reloaded
=
factory_func
()
reloaded
.
load_state_dict
(
imported
.
state_dict
())
reloaded
.
eval
()
# FeatureExtractor
x
=
torch
.
randn
(
3
,
1024
)
ref
,
_
=
imported
.
feature_extractor
(
x
,
None
)
hyp
,
_
=
reloaded
.
feature_extractor
(
x
,
None
)
self
.
assertEqual
(
ref
,
hyp
)
# Feature projection
x
=
torch
.
randn
(
3
,
10
,
config
[
"conv_dim"
][
-
1
])
ref
=
imported
.
encoder
.
feature_projection
(
x
)
hyp
=
reloaded
.
encoder
.
feature_projection
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# Convolutional Positional Encoder
x
=
torch
.
randn
(
3
,
256
,
config
[
"hidden_size"
])
ref
=
imported
.
encoder
.
transformer
.
pos_conv_embed
(
x
)
hyp
=
reloaded
.
encoder
.
transformer
.
pos_conv_embed
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# Encoder Transformer Layer
position_bias_ref
=
None
position_bias_hyp
=
None
for
imported_
,
reloaded_
in
zip
(
imported
.
encoder
.
transformer
.
layers
,
reloaded
.
encoder
.
transformer
.
layers
):
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
x
=
torch
.
randn
(
b
,
l
,
e
)
mask
=
torch
.
randn
(
b
,
l
)
>
0.5
# HugginFace WaveLM expects the mask to be binary
ref
,
position_bias_ref
=
imported_
(
x
,
key_padding_mask
=
mask
,
position_bias
=
position_bias_ref
)
hyp
,
position_bias_hyp
=
reloaded_
(
x
,
key_padding_mask
=
mask
,
position_bias
=
position_bias_hyp
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole Encoder Transformer
# TODO: Add mask pattern. Expected mask shapes and values are different.
b
,
l
,
e
=
16
,
3
,
config
[
"hidden_size"
]
x
=
torch
.
randn
(
b
,
l
,
e
)
mask
=
torch
.
randn
(
b
,
1
,
l
,
l
)
ref
=
imported
.
encoder
.
transformer
(
x
)
hyp
=
reloaded
.
encoder
.
transformer
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
# The whole model
x
=
torch
.
randn
(
3
,
1024
)
ref
,
_
=
imported
(
x
)
hyp
,
_
=
reloaded
(
x
)
self
.
assertEqual
(
ref
,
hyp
)
test/torchaudio_unittest/models/wav2vec2/model_test.py
View file @
ffeba11a
...
@@ -15,6 +15,8 @@ from torchaudio.models.wav2vec2 import (
...
@@ -15,6 +15,8 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base
,
wav2vec2_base
,
wav2vec2_large
,
wav2vec2_large
,
wav2vec2_large_lv60k
,
wav2vec2_large_lv60k
,
wavlm_base
,
wavlm_large
,
)
)
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
skipIfNoQengine
,
torch_script
,
TorchaudioTestCase
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
skipIfNoQengine
,
torch_script
,
TorchaudioTestCase
...
@@ -41,6 +43,14 @@ factory_funcs = parameterized.expand(
...
@@ -41,6 +43,14 @@ factory_funcs = parameterized.expand(
name_func
=
_name_func
,
name_func
=
_name_func
,
)
)
factory_funcs_wavlm
=
parameterized
.
expand
(
[
(
wavlm_base
,),
(
wavlm_large
,),
],
name_func
=
_name_func
,
)
factory_funcs_hubert_pretrain
=
parameterized
.
expand
(
factory_funcs_hubert_pretrain
=
parameterized
.
expand
(
[
[
(
hubert_pretrain_base
,),
(
hubert_pretrain_base
,),
...
@@ -278,6 +288,131 @@ class TestWav2Vec2Model(TorchaudioTestCase):
...
@@ -278,6 +288,131 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self
.
_test_quantize_torchscript
(
factory_func
(
aux_num_out
=
32
))
self
.
_test_quantize_torchscript
(
factory_func
(
aux_num_out
=
32
))
class
TestWavLMModel
(
TorchaudioTestCase
):
def
_smoke_test
(
self
,
model
,
device
,
dtype
):
model
=
model
.
to
(
device
=
device
,
dtype
=
dtype
)
model
=
model
.
eval
()
batch_size
,
num_frames
=
3
,
1024
waveforms
=
torch
.
randn
(
batch_size
,
num_frames
,
device
=
device
,
dtype
=
dtype
)
model
(
waveforms
)
@
parameterized
.
expand
([(
torch
.
float32
,),
(
torch
.
float64
,)])
def
test_cpu_smoke_test
(
self
,
dtype
):
model
=
wavlm_base
()
self
.
_smoke_test
(
model
,
torch
.
device
(
"cpu"
),
dtype
)
model
=
wavlm_base
(
aux_num_out
=
32
)
self
.
_smoke_test
(
model
,
torch
.
device
(
"cpu"
),
dtype
)
@
parameterized
.
expand
([(
torch
.
float32
,),
(
torch
.
float64
,)])
@
skipIfNoCuda
def
test_cuda_smoke_test
(
self
,
dtype
):
model
=
wavlm_base
()
self
.
_smoke_test
(
model
,
torch
.
device
(
"cuda"
),
dtype
)
model
=
wavlm_base
(
aux_num_out
=
32
)
self
.
_smoke_test
(
model
,
torch
.
device
(
"cuda"
),
dtype
)
def
_test_batch_consistency
(
self
,
model
):
model
.
eval
()
batch_size
,
max_frames
=
5
,
5
*
1024
waveforms
=
torch
.
randn
(
batch_size
,
max_frames
)
# Batch process
batch_logits
,
_
=
model
(
waveforms
)
# Par-sample process
for
i
in
range
(
batch_size
):
single_logit
,
_
=
model
(
waveforms
[
i
:
i
+
1
])
batch_logit
=
batch_logits
[
i
:
i
+
1
]
# Convert to probability so that it's easier to interpretate the diff
single_prob
=
F
.
softmax
(
single_logit
,
dim
=
2
)
batch_prob
=
F
.
softmax
(
batch_logit
,
dim
=
2
)
# We allow max atol=0.005 -> 0.5%
self
.
assertEqual
(
single_prob
,
batch_prob
,
atol
=
0.005
,
rtol
=
0
)
@
factory_funcs_wavlm
def
test_pretrain_batch_consistency
(
self
,
factory_func
):
"""Results from single process and batched process should be reasonably close"""
self
.
_test_batch_consistency
(
factory_func
())
@
factory_funcs_wavlm
def
test_finetune_batch_consistency
(
self
,
factory_func
):
"""Results from single process and batched process should be reasonably close"""
self
.
_test_batch_consistency
(
factory_func
(
aux_num_out
=
32
))
def
_test_torchscript
(
self
,
model
):
model
.
eval
()
batch_size
,
num_frames
=
3
,
1024
waveforms
=
torch
.
randn
(
batch_size
,
num_frames
)
# Compute results with original model
ref_out
,
ref_len
=
model
(
waveforms
)
# Compute results with scripted model
scripted
=
torch_script
(
model
)
hyp_out
,
hyp_len
=
scripted
(
waveforms
)
self
.
assertEqual
(
hyp_out
,
ref_out
)
self
.
assertEqual
(
hyp_len
,
ref_len
)
@
factory_funcs_wavlm
def
test_pretrain_torchscript
(
self
,
factory_func
):
"""WavLM model should be scriptable"""
self
.
_test_torchscript
(
factory_func
())
@
factory_funcs_wavlm
def
test_finetune_torchscript
(
self
,
factory_func
):
"""WavLM model with a head should be scriptable"""
self
.
_test_torchscript
(
factory_func
(
aux_num_out
=
32
))
def
_test_quantize_smoke_test
(
self
,
model
):
model
.
eval
()
batch_size
,
num_frames
=
3
,
1024
# Remove the weight normalization forward hook
model
.
encoder
.
transformer
.
pos_conv_embed
.
__prepare_scriptable__
()
quantized
=
tq
.
quantize_dynamic
(
model
,
qconfig_spec
=
{
torch
.
nn
.
Linear
},
dtype
=
torch
.
qint8
)
# A lazy way to check that Modules are different
assert
str
(
quantized
)
!=
str
(
model
),
"Dynamic quantization did not modify the module."
waveforms
=
torch
.
randn
(
batch_size
,
num_frames
)
_
,
_
=
quantized
(
waveforms
)
@
factory_funcs_wavlm
@
skipIfNoQengine
def
test_quantize
(
self
,
factory_func
):
"""WavLM should support basic quantization"""
self
.
_test_quantize_smoke_test
(
factory_func
(
aux_num_out
=
32
))
def
_test_quantize_torchscript
(
self
,
model
):
model
.
eval
()
batch_size
,
num_frames
=
3
,
1024
# Remove the weight normalization forward hook
model
.
encoder
.
transformer
.
pos_conv_embed
.
__prepare_scriptable__
()
quantized
=
tq
.
quantize_dynamic
(
model
,
qconfig_spec
=
{
torch
.
nn
.
Linear
},
dtype
=
torch
.
qint8
)
# A lazy way to check that Modules are different
assert
str
(
quantized
)
!=
str
(
model
),
"Dynamic quantization did not modify the module."
waveforms
=
torch
.
randn
(
batch_size
,
num_frames
)
ref_out
,
ref_len
=
quantized
(
waveforms
)
# Script
scripted
=
torch_script
(
quantized
)
hyp_out
,
hyp_len
=
scripted
(
waveforms
)
self
.
assertEqual
(
hyp_out
,
ref_out
)
self
.
assertEqual
(
hyp_len
,
ref_len
)
@
factory_funcs_wavlm
@
skipIfNoQengine
def
test_quantize_torchscript
(
self
,
factory_func
):
"""Quantized WavLM model should be scriptable"""
self
.
_test_quantize_torchscript
(
factory_func
(
aux_num_out
=
32
))
def
_compute_label_frame
(
audio_frame
:
int
)
->
int
:
def
_compute_label_frame
(
audio_frame
:
int
)
->
int
:
"""Compute number of frames in the label tensor based on
"""Compute number of frames in the label tensor based on
the number of frames in the audio tensor."""
the number of frames in the audio tensor."""
...
...
test/torchaudio_unittest/sox_effect/dataset_test.py
View file @
ffeba11a
import
os
import
os
import
platform
import
platform
import
sys
import
sys
import
unittest
from
concurrent.futures
import
ProcessPoolExecutor
from
concurrent.futures
import
ProcessPoolExecutor
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
from
unittest
import
skipIf
from
unittest
import
skipIf
...
@@ -94,8 +95,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
...
@@ -94,8 +95,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
loader
=
torch
.
utils
.
data
.
DataLoader
(
loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
dataset
,
batch_size
=
32
,
batch_size
=
32
,
num_workers
=
16
,
num_workers
=
4
,
worker_init_fn
=
init_random_seed
,
worker_init_fn
=
init_random_seed
,
multiprocessing_context
=
torch
.
multiprocessing
.
get_context
(
"spawn"
),
)
)
for
batch
in
loader
:
for
batch
in
loader
:
assert
batch
.
shape
==
(
32
,
2
,
2
*
sample_rate
)
assert
batch
.
shape
==
(
32
,
2
,
2
*
sample_rate
)
...
@@ -115,8 +117,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
...
@@ -115,8 +117,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
loader
=
torch
.
utils
.
data
.
DataLoader
(
loader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
dataset
,
batch_size
=
32
,
batch_size
=
32
,
num_workers
=
16
,
num_workers
=
4
,
worker_init_fn
=
init_random_seed
,
worker_init_fn
=
init_random_seed
,
multiprocessing_context
=
torch
.
multiprocessing
.
get_context
(
"spawn"
),
)
)
for
batch
in
loader
:
for
batch
in
loader
:
assert
batch
.
shape
==
(
32
,
2
,
2
*
sample_rate
)
assert
batch
.
shape
==
(
32
,
2
,
2
*
sample_rate
)
...
@@ -131,6 +134,7 @@ def speed(path):
...
@@ -131,6 +134,7 @@ def speed(path):
return
torchaudio
.
sox_effects
.
apply_effects_tensor
(
wav
,
sample_rate
,
effects
)[
0
]
return
torchaudio
.
sox_effects
.
apply_effects_tensor
(
wav
,
sample_rate
,
effects
)[
0
]
@
unittest
.
skipIf
(
True
,
"Skipping this test because condition is True"
)
@
skipIfNoSox
@
skipIfNoSox
class
TestProcessPoolExecutor
(
TempDirMixin
,
PytorchTestCase
):
class
TestProcessPoolExecutor
(
TempDirMixin
,
PytorchTestCase
):
backend
=
"sox_io"
backend
=
"sox_io"
...
...
test/torchaudio_unittest/sox_effect/smoke_test.py
View file @
ffeba11a
...
@@ -54,24 +54,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
...
@@ -54,24 +54,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
_found
,
_sr
=
sox_effects
.
apply_effects_file
(
_found
,
_sr
=
sox_effects
.
apply_effects_file
(
input_path
,
effects
,
normalize
=
False
,
channels_first
=
channels_first
input_path
,
effects
,
normalize
=
False
,
channels_first
=
channels_first
)
)
@
parameterized
.
expand
(
load_params
(
"sox_effect_test_args.jsonl"
),
name_func
=
lambda
f
,
i
,
p
:
f
'
{
f
.
__name__
}
_
{
i
}
_
{
p
.
args
[
0
][
"effects"
][
0
][
0
]
}
'
,
)
def
test_apply_effects_fileobj
(
self
,
args
):
"""`apply_effects_file` should return identical data as sox command"""
dtype
=
"int32"
channels_first
=
True
effects
=
args
[
"effects"
]
num_channels
=
args
.
get
(
"num_channels"
,
2
)
input_sr
=
args
.
get
(
"input_sample_rate"
,
8000
)
input_path
=
self
.
get_temp_path
(
"input.wav"
)
data
=
get_wav_data
(
dtype
,
num_channels
,
channels_first
=
channels_first
)
save_wav
(
input_path
,
data
,
input_sr
,
channels_first
=
channels_first
)
with
open
(
input_path
,
"rb"
)
as
fileobj
:
_found
,
_sr
=
sox_effects
.
apply_effects_file
(
fileobj
,
effects
,
normalize
=
False
,
channels_first
=
channels_first
)
Prev
1
…
9
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment