Commit ffeba11a authored by mayp777's avatar mayp777
Browse files

UPDATE

parent 29deb085
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .librosa_compatibility_test_impl import Functional, FunctionalComplex from .librosa_compatibility_test_impl import Functional, FunctionalComplex
@skipIfNoCuda @skipIfNoCuda
@skipIfkmeMark
class TestFunctionalCUDA(Functional, PytorchTestCase): class TestFunctionalCUDA(Functional, PytorchTestCase):
device = "cuda" device = "cuda"
@skipIfNoCuda @skipIfNoCuda
@skipIfkmeMark
class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase): class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase):
device = "cuda" device = "cuda"
import unittest import unittest
from distutils.version import StrictVersion from distutils.version import LooseVersion
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
...@@ -77,7 +77,7 @@ class Functional(TestBaseMixin): ...@@ -77,7 +77,7 @@ class Functional(TestBaseMixin):
def test_create_mel_fb( def test_create_mel_fb(
self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk" self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"
): ):
if norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2"): if norm == "slaney" and LooseVersion(librosa.__version__) < LooseVersion("0.7.2"):
self.skipTest("Test is known to fail with older versions of librosa.") self.skipTest("Test is known to fail with older versions of librosa.")
if self.device != "cpu": if self.device != "cpu":
self.skipTest("No need to run this test on CUDA") self.skipTest("No need to run this test on CUDA")
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda, skipIfkmeMark
from .torchscript_consistency_impl import Functional, FunctionalFloat32Only from .torchscript_consistency_impl import Functional, FunctionalFloat32Only
...@@ -11,6 +11,7 @@ class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase): ...@@ -11,6 +11,7 @@ class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase):
@skipIfNoCuda @skipIfNoCuda
@skipIfkmeMark
class TestFunctionalFloat64(Functional, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device("cuda") device = torch.device("cuda")
...@@ -585,22 +585,10 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -585,22 +585,10 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100) tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, (tensor,)) self._assert_consistency(func, (tensor,))
@common_utils.skipIfNoKaldi
def test_compute_kaldi_pitch(self):
if self.dtype != torch.float32 or self.device != torch.device("cpu"):
raise unittest.SkipTest("Only float32, cpu is supported.")
def func(tensor):
sample_rate: float = 44100.0
return F.compute_kaldi_pitch(tensor, sample_rate)
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, (tensor,))
def test_resample_sinc(self): def test_resample_sinc(self):
def func(tensor): def func(tensor):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation") return F.resample(tensor, sr1, sr2, resampling_method="sinc_interp_hann")
tensor = common_utils.get_whitenoise(sample_rate=16000) tensor = common_utils.get_whitenoise(sample_rate=16000)
self._assert_consistency(func, (tensor,)) self._assert_consistency(func, (tensor,))
...@@ -616,7 +604,9 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -616,7 +604,9 @@ class Functional(TempDirMixin, TestBaseMixin):
sr1, sr2 = 16000, 8000 sr1, sr2 = 16000, 8000
lowpass_filter_width = 6 lowpass_filter_width = 6
rolloff = 0.99 rolloff = 0.99
self._assert_consistency(F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "kaiser_window", beta)) self._assert_consistency(
F.resample, (tensor, sr1, sr2, lowpass_filter_width, rolloff, "sinc_interp_kaiser", beta)
)
def test_phase_vocoder(self): def test_phase_vocoder(self):
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2)) tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
...@@ -756,6 +746,54 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -756,6 +746,54 @@ class Functional(TempDirMixin, TestBaseMixin):
specgram = torch.rand(num_channels, n_fft_bin, num_frames, dtype=self.complex_dtype, device=self.device) specgram = torch.rand(num_channels, n_fft_bin, num_frames, dtype=self.complex_dtype, device=self.device)
self._assert_consistency_complex(F.apply_beamforming, (beamform_weights, specgram)) self._assert_consistency_complex(F.apply_beamforming, (beamform_weights, specgram))
@common_utils.nested_params(
["convolve", "fftconvolve"],
["full", "valid", "same"],
)
def test_convolve(self, fn, mode):
leading_dims = (2, 3, 2)
L_x, L_y = 32, 55
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
self._assert_consistency(getattr(F, fn), (x, y, mode))
@common_utils.nested_params([True, False])
def test_add_noise(self, use_lengths):
leading_dims = (2, 3)
L = 31
waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device, requires_grad=True)
if use_lengths:
lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True)
else:
lengths = None
snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device, requires_grad=True) * 10
self._assert_consistency(F.add_noise, (waveform, noise, snr, lengths))
@common_utils.nested_params([True, False])
def test_speed(self, use_lengths):
leading_dims = (3, 2)
T = 200
waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
if use_lengths:
lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
else:
lengths = None
self._assert_consistency(F.speed, (waveform, 1000, 1.1, lengths))
def test_preemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
self._assert_consistency(F.preemphasis, (waveform, coeff))
def test_deemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
self._assert_consistency(F.deemphasis, (waveform, coeff))
class FunctionalFloat32Only(TestBaseMixin): class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self): def test_rnnt_loss(self):
......
import torchaudio
# If FFmpeg is 4.1 or older
# Tests that checks the number of output samples from OPUS fails
# They work on 4.2+
# Probably this commit fixed it.
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
def lt42():
ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavcodec"]
# 5.1 libavcodec 59. 18.100
# 4.4 libavcodec 58.134.100
# 4.3 libavcodec 58. 91.100
# 4.2 libavcodec 58. 54.100
# 4.1 libavcodec 58. 35.100
return ver[0] < 59 and ver[1] < 54
from parameterized import parameterized
from torchaudio.io import AudioEffector
from torchaudio_unittest.common_utils import get_sinusoid, skipIfNoFFmpeg, TorchaudioTestCase
from .common import lt42
@skipIfNoFFmpeg
class EffectorTest(TorchaudioTestCase):
def test_null(self):
"""No effect and codec will return the same result"""
sample_rate = 8000
frames_per_chunk = 256
effector = AudioEffector(effect=None, format=None)
original = get_sinusoid(n_channels=3, sample_rate=sample_rate, channels_first=False)
# one-go
output = effector.apply(original, sample_rate)
self.assertEqual(original, output)
# streaming
for i, chunk in enumerate(effector.stream(original, sample_rate, frames_per_chunk)):
start = i * frames_per_chunk
end = (i + 1) * frames_per_chunk
self.assertEqual(original[start:end, :], chunk)
@parameterized.expand(
[
("ogg", "flac"), # flac only supports s16 and s32
("ogg", "opus"), # opus only supports 48k Hz
("ogg", "vorbis"), # vorbis only supports stereo
# ("ogg", "vorbis", 44100),
# this fails with small descrepancy; 441024 vs 441000
# TODO: investigate
("wav", None),
("wav", "pcm_u8"),
("mp3", None),
("mulaw", None, 44100), # mulaw is encoded without header
]
)
def test_formats(self, format, encoder, sample_rate=8000):
"""Formats (some with restrictions) just work without an issue in effector"""
effector = AudioEffector(format=format, encoder=encoder)
original = get_sinusoid(n_channels=3, sample_rate=sample_rate, channels_first=False)
output = effector.apply(original, sample_rate)
# On 4.1 OPUS produces 8020 samples (extra 20)
# this has been fixed on 4.2+
if encoder == "opus" and lt42():
return
self.assertEqual(original.shape, output.shape)
# Note
# MP3 adds padding which cannot be removed when the encoded data is written to
# file-like object without seek method.
# The number of padding is retrievable as `AVCoedcContext::initial_padding`
# https://ffmpeg.org/doxygen/4.1/structAVCodecContext.html#a8f95550ce04f236e9915516d04d3d1ab
# but this is not exposed yet.
# These "priming" samples have negative time stamp, so we can also add logic
# to discard them at decoding, however, as far as I checked, when data is loaded
# with StreamReader, the time stamp is reset. I tried options like avoid_negative_ts,
# https://ffmpeg.org/ffmpeg-formats.html
# but it made no difference. Perhaps this is because the information about negative
# timestamp is only available at encoding side, and it presumably is written to
# header file, but it is not happening somehow with file-like object.
# Need to investigate more to remove MP3 padding
if format == "mp3":
return
for chunk in effector.stream(original, sample_rate, frames_per_chunk=original.size(0)):
self.assertEqual(original.shape, chunk.shape)
@parameterized.expand([("loudnorm=I=-16:LRA=11:TP=-1.5",), ("volume=2",)])
def test_effect(self, effect):
sample_rate = 8000
effector = AudioEffector(effect=effect)
original = get_sinusoid(n_channels=3, sample_rate=sample_rate, channels_first=False)
output = effector.apply(original, sample_rate)
self.assertEqual(original.shape, output.shape)
def test_resample(self):
"""Resample option allows to change the sampling rate"""
sample_rate = 8000
output_sample_rate = 16000
num_channels = 3
effector = AudioEffector(effect="lowpass")
original = get_sinusoid(n_channels=num_channels, sample_rate=sample_rate, channels_first=False)
output = effector.apply(original, sample_rate, output_sample_rate)
self.assertEqual(output.shape, [output_sample_rate, num_channels])
for chunk in effector.stream(
original, sample_rate, output_sample_rate=output_sample_rate, frames_per_chunk=output_sample_rate
):
self.assertEqual(chunk.shape, [output_sample_rate, num_channels])
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio.io import play_audio, StreamWriter
from torchaudio_unittest.common_utils import get_sinusoid, skipIfNoAudioDevice, skipIfNoMacOS, TorchaudioTestCase
@skipIfNoAudioDevice
@skipIfNoMacOS
class PlaybackInterfaceTest(TorchaudioTestCase):
@parameterized.expand([("uint8",), ("int16",), ("int32",), ("int64",), ("float32",), ("float64",)])
@patch.object(StreamWriter, "write_audio_chunk")
def test_playaudio(self, dtype, writeaudio_mock):
"""Test playaudio function.
The patch object is used to check if the data is written
to the output device stream, without playing the actual audio.
"""
dtype = getattr(torch, dtype)
sample_rate = 8000
waveform = get_sinusoid(
frequency=440,
sample_rate=sample_rate,
duration=1, # seconds
n_channels=1,
dtype=dtype,
device="cpu",
channels_first=False,
)
play_audio(waveform, sample_rate=sample_rate)
writeaudio_mock.assert_called()
@parameterized.expand(
[
# Invalid number of dimensions (!= 2)
("int16", 1, "audiotoolbox"),
("int16", 3, "audiotoolbox"),
# Invalid tensor type
("complex64", 2, "audiotoolbox"),
# Invalid output device
("int16", 2, "audiotool"),
]
)
@patch.object(StreamWriter, "write_audio_chunk")
def test_playaudio_invalid_options(self, dtype, ndim, device, writeaudio_mock):
"""Test playaudio function raises error with invalid options."""
dtype = getattr(torch, dtype)
sample_rate = 8000
waveform = get_sinusoid(
frequency=440,
sample_rate=sample_rate,
duration=1, # seconds
n_channels=1,
dtype=dtype,
device="cpu",
channels_first=False,
).squeeze()
for _ in range(ndim - 1):
waveform = waveform.unsqueeze(-1)
with self.assertRaises(ValueError):
play_audio(waveform, sample_rate=sample_rate, device=device)
import io
import torch import torch
import torchaudio import torchaudio
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
disabledInCI,
get_asset_path, get_asset_path,
get_image, get_image,
get_sinusoid,
get_wav_data, get_wav_data,
is_ffmpeg_available, is_ffmpeg_available,
nested_params, nested_params,
...@@ -12,24 +16,68 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,24 +16,68 @@ from torchaudio_unittest.common_utils import (
save_image, save_image,
save_wav, save_wav,
skipIfNoFFmpeg, skipIfNoFFmpeg,
skipIfNoHWAccel,
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
) )
if is_ffmpeg_available(): if is_ffmpeg_available():
from torchaudio.io import ( from torchaudio.io import StreamReader, StreamWriter
StreamReader, from torchaudio.io._stream_reader import (
StreamReaderSourceAudioStream, ChunkTensor,
StreamReaderSourceStream, OutputAudioStream,
StreamReaderSourceVideoStream, OutputVideoStream,
SourceAudioStream,
SourceStream,
SourceVideoStream,
) )
@skipIfNoFFmpeg
class ChunkTensorTest(TorchaudioTestCase):
def test_chunktensor(self):
"""ChunkTensor serves as a replacement of tensor"""
data = torch.randn((256, 2))
pts = 16.0
c = ChunkTensor(data, pts)
assert c.pts == pts
self.assertEqual(c, data)
# method
sum_ = c.sum()
assert isinstance(sum_, torch.Tensor)
self.assertEqual(sum_, data.sum())
# function form
min_ = torch.min(c)
assert isinstance(min_, torch.Tensor)
self.assertEqual(min_, torch.min(data))
# attribute
t = c.T
assert isinstance(t, torch.Tensor)
self.assertEqual(t, data.T)
# in-place op
c[0] = 0
self.assertEqual(c, data)
# pass to other C++ code
buffer = io.BytesIO()
w = StreamWriter(buffer, format="wav")
w.add_audio_stream(8000, 2)
with w.open():
w.write_audio_chunk(0, c)
w.write_audio_chunk(0, c, c.pts)
################################################################################ ################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj # Helper decorator and Mixin to duplicate the tests for fileobj
_media_source = parameterized_class( _media_source = parameterized_class(
("test_type",), ("test_type",),
[("str",), ("fileobj",), ("tensor",)], [("str",), ("fileobj",)],
class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}', class_name_func=lambda cls, _, params: f'{cls.__name__}_{params["test_type"]}',
) )
...@@ -47,13 +95,6 @@ class _MediaSourceMixin: ...@@ -47,13 +95,6 @@ class _MediaSourceMixin:
self.src = path self.src = path
elif self.test_type == "fileobj": elif self.test_type == "fileobj":
self.src = open(path, "rb") self.src = open(path, "rb")
elif self.test_type == "tensor":
with open(path, "rb") as fileobj:
data = fileobj.read()
self.src = torch.frombuffer(data, dtype=torch.uint8)
print(self.src.data_ptr())
print(len(data))
print(self.src.shape)
return self.src return self.src
def tearDown(self): def tearDown(self):
...@@ -112,7 +153,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -112,7 +153,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
base_metadata = {} base_metadata = {}
expected = [ expected = [
StreamReaderSourceVideoStream( SourceVideoStream(
media_type="video", media_type="video",
codec="h264", codec="h264",
codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10", codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
...@@ -129,7 +170,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -129,7 +170,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height=180, height=180,
frame_rate=25.0, frame_rate=25.0,
), ),
StreamReaderSourceAudioStream( SourceAudioStream(
media_type="audio", media_type="audio",
codec="aac", codec="aac",
codec_long_name="AAC (Advanced Audio Coding)", codec_long_name="AAC (Advanced Audio Coding)",
...@@ -145,7 +186,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -145,7 +186,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate=8000.0, sample_rate=8000.0,
num_channels=2, num_channels=2,
), ),
StreamReaderSourceStream( SourceStream(
media_type="subtitle", media_type="subtitle",
codec="mov_text", codec="mov_text",
codec_long_name="MOV text", codec_long_name="MOV text",
...@@ -158,7 +199,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -158,7 +199,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
"language": "eng", "language": "eng",
}, },
), ),
StreamReaderSourceVideoStream( SourceVideoStream(
media_type="video", media_type="video",
codec="h264", codec="h264",
codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10", codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
...@@ -175,7 +216,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -175,7 +216,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
height=270, height=270,
frame_rate=29.97002997002997, frame_rate=29.97002997002997,
), ),
StreamReaderSourceAudioStream( SourceAudioStream(
media_type="audio", media_type="audio",
codec="aac", codec="aac",
codec_long_name="AAC (Advanced Audio Coding)", codec_long_name="AAC (Advanced Audio Coding)",
...@@ -191,7 +232,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -191,7 +232,7 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
sample_rate=16000.0, sample_rate=16000.0,
num_channels=2, num_channels=2,
), ),
StreamReaderSourceStream( SourceStream(
media_type="subtitle", media_type="subtitle",
codec="mov_text", codec="mov_text",
codec_long_name="MOV text", codec_long_name="MOV text",
...@@ -208,6 +249,98 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -208,6 +249,98 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
output = [s.get_src_stream_info(i) for i in range(6)] output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output assert expected == output
def test_output_info(self):
s = StreamReader(self.get_src())
s.add_audio_stream(-1)
s.add_audio_stream(-1, filter_desc="aresample=8000")
s.add_audio_stream(-1, filter_desc="aformat=sample_fmts=s16p")
s.add_video_stream(-1)
s.add_video_stream(-1, filter_desc="fps=10")
s.add_video_stream(-1, filter_desc="format=rgb24")
s.add_video_stream(-1, filter_desc="scale=w=160:h=90")
# Note:
# Somehow only FFmpeg 5 reports invalid video frame rate. (24576/0)
# FFmpeg 4 and 6 work fine.
# Perhaps this is a regression in FFmpeg or it could actually originate
# from other libraries.
# It consistently fails with FFmpeg installed via conda, so we change
# the value based on FFmpeg version.
ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavutil"]
print(ver)
major, minor, _ = ver
if major == 57:
video_frame_rate = -1
else:
video_frame_rate = 30000 / 1001
print(video_frame_rate)
expected = [
OutputAudioStream(
source_index=4,
filter_description="anull",
media_type="audio",
format="fltp",
sample_rate=16000.0,
num_channels=2,
),
OutputAudioStream(
source_index=4,
filter_description="aresample=8000",
media_type="audio",
format="fltp",
sample_rate=8000.0,
num_channels=2,
),
OutputAudioStream(
source_index=4,
filter_description="aformat=sample_fmts=s16p",
media_type="audio",
format="s16p",
sample_rate=16000.0,
num_channels=2,
),
OutputVideoStream(
source_index=3,
filter_description="null",
media_type="video",
format="yuv420p",
width=480,
height=270,
frame_rate=30000 / 1001,
),
OutputVideoStream(
source_index=3,
filter_description="fps=10",
media_type="video",
format="yuv420p",
width=480,
height=270,
frame_rate=10,
),
OutputVideoStream(
source_index=3,
filter_description="format=rgb24",
media_type="video",
format="rgb24",
width=480,
height=270,
frame_rate=30000 / 1001,
),
OutputVideoStream(
source_index=3,
filter_description="scale=w=160:h=90",
media_type="video",
format="yuv420p",
width=160,
height=90,
frame_rate=30000 / 1001,
),
]
output = [s.get_out_stream_info(i) for i in range(s.num_out_streams)]
assert expected == output
def test_id3tag(self): def test_id3tag(self):
"""get_metadata method can fetch id3tag properly""" """get_metadata method can fetch id3tag properly"""
s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3")) s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3"))
...@@ -418,15 +551,26 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -418,15 +551,26 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
if i >= 40: if i >= 40:
break break
def test_seek(self): def test_stream_requires_grad_false(self):
"""Tensors produced by StreamReader are requires_grad=False"""
s = StreamReader(self.get_src())
s.add_basic_audio_stream(frames_per_chunk=2000)
s.add_basic_video_stream(frames_per_chunk=15)
s.fill_buffer()
audio, video = s.pop_chunks()
assert not audio._elem.requires_grad
assert not video._elem.requires_grad
@parameterized.expand(["key", "any", "precise"])
def test_seek(self, mode):
"""Calling `seek` multiple times should not segfault""" """Calling `seek` multiple times should not segfault"""
s = StreamReader(self.get_src()) s = StreamReader(self.get_src())
for i in range(10): for i in range(10):
s.seek(i) s.seek(i, mode)
for _ in range(0): for _ in range(0):
s.seek(0) s.seek(0, mode)
for i in range(10, 0, -1): for i in range(10, 0, -1):
s.seek(i) s.seek(i, mode)
def test_seek_negative(self): def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception""" """Calling `seek` with negative value should raise an exception"""
...@@ -434,6 +578,232 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -434,6 +578,232 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
s.seek(-1.0) s.seek(-1.0)
def test_seek_invalid_mode(self):
"""Calling `seek` with an invalid model should raise an exception"""
s = StreamReader(self.get_src())
with self.assertRaises(ValueError):
s.seek(10, "magic_seek")
@parameterized.expand(
[
# Test keyframe seek
# The source mp4 video has two key frames the first frame and 203rd frame at 8.08 second.
# If the seek time stamp is smaller than 8.08, it will seek into the first frame at 0.0 second.
("nasa_13013.mp4", "key", 0.2, (0, slice(None))),
("nasa_13013.mp4", "key", 8.04, (0, slice(None))),
("nasa_13013.mp4", "key", 8.08, (0, slice(202, None))),
("nasa_13013.mp4", "key", 8.12, (0, slice(202, None))),
# The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds.
# if we seek to a time stamp smaller than 0.4004 it will seek into the first frame at 0.0 second.
("nasa_13013.avi", "key", 0.2, (0, slice(None))),
("nasa_13013.avi", "key", 1.01, (0, slice(24, None))),
("nasa_13013.avi", "key", 7.37, (0, slice(216, None))),
("nasa_13013.avi", "key", 7.7, (0, slice(216, None))),
# Test precise seek
("nasa_13013.mp4", "precise", 0.0, (0, slice(None))),
("nasa_13013.mp4", "precise", 0.2, (0, slice(5, None))),
("nasa_13013.mp4", "precise", 8.04, (0, slice(201, None))),
("nasa_13013.mp4", "precise", 8.08, (0, slice(202, None))),
("nasa_13013.mp4", "precise", 8.12, (0, slice(203, None))),
("nasa_13013.avi", "precise", 0.0, (0, slice(None))),
("nasa_13013.avi", "precise", 0.2, (0, slice(1, None))),
("nasa_13013.avi", "precise", 8.1, (0, slice(238, None))),
("nasa_13013.avi", "precise", 8.14, (0, slice(239, None))),
("nasa_13013.avi", "precise", 8.17, (0, slice(240, None))),
# Test precise seek on video with missing PTS
("RATRACE_wave_f_nm_np1_fr_goo_37.avi", "precise", 0.0, (0, slice(None))),
("RATRACE_wave_f_nm_np1_fr_goo_37.avi", "precise", 0.2, (0, slice(4, None))),
("RATRACE_wave_f_nm_np1_fr_goo_37.avi", "precise", 0.3, (0, slice(7, None))),
# Test any seek
# The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds.
("nasa_13013.avi", "any", 0.0, (0, slice(None))),
("nasa_13013.avi", "any", 0.56, (0, slice(12, None))),
("nasa_13013.avi", "any", 7.77, (0, slice(228, None))),
("nasa_13013.avi", "any", 0.2002, (11, slice(12, None))),
("nasa_13013.avi", "any", 0.233567, (10, slice(12, None))),
("nasa_13013.avi", "any", 0.266933, (9, slice(12, None))),
]
)
def test_seek_modes(self, src, mode, seek_time, ref_indices):
"""We expect the following behaviour from the diferent kinds of seek:
- `key`: the reader will seek to the first keyframe from the timestamp given
- `precise`: the reader will seek to the first keyframe from the timestamp given
and start decoding from that position until the given timestmap (discarding all frames in between)
- `any`: the reader will seek to the colsest frame to the timestamp
given but if this is not a keyframe, the content will be the delta from other frames
To thest this behaviour we can parameterize the test with the tupple ref_indices. ref_indices[0]
is the expected index on the frames list decoded after seek and ref_indices[1] is exepected index for
the list of all frames decoded from the begining (reference frames). This test checks if
the reference frame at index ref_indices[1] is the same as ref_indices[0]. Plese note that with `any`
and `key` seek we only compare keyframes, but with `precise` seek we can compare any frame content.
"""
# Using the first video stream (which is not default video stream)
stream_index = 0
# Decode all frames for reference
src_bin = self.get_src(src)
s = StreamReader(src_bin)
s.add_basic_video_stream(-1, stream_index=stream_index)
s.process_all_packets()
(ref_frames,) = s.pop_chunks()
s.seek(seek_time, mode=mode)
s.process_all_packets()
(frame,) = s.pop_chunks()
hyp_index, ref_index = ref_indices
hyp, ref = frame[hyp_index:], ref_frames[ref_index]
print(hyp.shape, ref.shape)
self.assertEqual(hyp, ref)
@parameterized.expand(
[
("nasa_13013.mp4", [195, 3, 270, 480]),
# RATRACE does not have valid PTS metadata.
("RATRACE_wave_f_nm_np1_fr_goo_37.avi", [36, 3, 240, 560]),
]
)
def test_change_fps(self, src, shape):
"""Can change the FPS of videos"""
tgt_frame_rate = 15
s = StreamReader(self.get_src(src))
info = s.get_src_stream_info(s.default_video_stream)
assert info.frame_rate != tgt_frame_rate
s.add_basic_video_stream(frames_per_chunk=-1, frame_rate=tgt_frame_rate)
s.process_all_packets()
(chunk,) = s.pop_chunks()
assert chunk.shape == torch.Size(shape)
def test_invalid_chunk_option(self):
"""Passing invalid `frames_per_chunk` and `buffer_chunk_size` raises error"""
s = StreamReader(self.get_src())
for fpc, bcs in ((0, 3), (3, 0), (-2, 3), (3, -2)):
with self.assertRaises(RuntimeError):
s.add_audio_stream(frames_per_chunk=fpc, buffer_chunk_size=bcs)
with self.assertRaises(RuntimeError):
s.add_video_stream(frames_per_chunk=fpc, buffer_chunk_size=bcs)
def test_unchunked_stream(self):
"""`frames_per_chunk=-1` disable chunking.
When chunking is disabled, frames contained in one AVFrame become one chunk.
For video, that is always one frame, but for audio, it depends.
"""
s = StreamReader(self.get_src())
s.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=10000)
s.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=10000)
s.process_all_packets()
video, audio = s.pop_chunks()
assert video.shape == torch.Size([390, 3, 270, 480])
assert audio.shape == torch.Size([208896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_frames_per_chunk(self, fpc):
"""Changing frames_per_chunk does not change the returned content"""
src = self.get_src()
s = StreamReader(src)
s.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
s.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
s.process_all_packets()
ref_video, ref_audio = s.pop_chunks()
if self.test_type == "fileobj":
src.seek(0)
s = StreamReader(src)
s.add_video_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
s.add_audio_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
chunks = list(s.stream())
video_chunks = torch.cat([c[0] for c in chunks if c[0] is not None])
audio_chunks = torch.cat([c[1] for c in chunks if c[1] is not None])
self.assertEqual(ref_video, video_chunks)
self.assertEqual(ref_audio, audio_chunks)
def test_buffer_chunk_size(self):
"""`buffer_chunk_size=-1` does not drop frames."""
src = self.get_src()
s = StreamReader(src)
s.add_video_stream(frames_per_chunk=30, buffer_chunk_size=-1)
s.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=-1)
s.process_all_packets()
for _ in range(13):
video, audio = s.pop_chunks()
assert video.shape == torch.Size([30, 3, 270, 480])
assert audio.shape == torch.Size([16000, 2])
video, audio = s.pop_chunks()
assert video is None
assert audio.shape == torch.Size([896, 2])
if self.test_type == "fileobj":
src.seek(0)
s = StreamReader(src)
s.add_video_stream(frames_per_chunk=30, buffer_chunk_size=3)
s.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=3)
s.process_all_packets()
for _ in range(2):
video, audio = s.pop_chunks()
assert video.shape == torch.Size([30, 3, 270, 480])
assert audio.shape == torch.Size([16000, 2])
video, audio = s.pop_chunks()
assert video.shape == torch.Size([30, 3, 270, 480])
assert audio.shape == torch.Size([896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_video_pts(self, fpc):
"""PTS values of the first frame are reported in .pts attribute"""
rate, num_frames = 30000 / 1001, 390
ref_pts = [i / rate for i in range(0, num_frames, fpc)]
s = StreamReader(self.get_src())
s.add_video_stream(fpc)
pts = [video.pts for video, in s.stream()]
self.assertEqual(pts, ref_pts)
@parameterized.expand([(256,), (512,), (1024,), (4086,)])
def test_audio_pts(self, fpc):
"""PTS values of the first frame are reported in .pts attribute"""
rate, num_frames = 16000, 208896
ref_pts = [i / rate for i in range(0, num_frames, fpc)]
s = StreamReader(self.get_src())
s.add_audio_stream(fpc, buffer_chunk_size=-1)
pts = [audio.pts for audio, in s.stream()]
self.assertEqual(pts, ref_pts)
def test_pts_unchunked_process_all(self):
"""PTS is zero when loading the entire media with unchunked buffer"""
s = StreamReader(self.get_src())
s.add_audio_stream(-1, buffer_chunk_size=-1)
s.add_video_stream(-1, buffer_chunk_size=-1)
s.process_all_packets()
audio, video = s.pop_chunks()
assert audio.pts == 0.0
assert video.pts == 0.0
assert audio.size(0) == 208896
assert video.size(0) == 390
def test_pts_unchunked(self):
"""PTS grows proportionally to the number of frames decoded"""
s = StreamReader(self.get_src())
s.add_audio_stream(-1, buffer_chunk_size=-1)
s.add_video_stream(-1, buffer_chunk_size=-1)
num_audio_frames, num_video_frames = 0, 0
while num_audio_frames < 208896 and num_video_frames < 390:
s.process_packet()
audio, video = s.pop_chunks()
if audio is None and video is None:
continue
if audio is not None:
assert audio.pts == num_audio_frames / 16000
num_audio_frames += audio.size(0)
if video is not None:
assert video.pts == num_video_frames * 1001 / 30000
num_video_frames += video.size(0)
def _to_fltp(original): def _to_fltp(original):
"""Convert Tensor to float32 with value range [-1, 1]""" """Convert Tensor to float32 with value range [-1, 1]"""
...@@ -493,11 +863,84 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase) ...@@ -493,11 +863,84 @@ class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
if self.test_type == "fileobj": if self.test_type == "fileobj":
src.seek(0) src.seek(0)
self._test_wav(src, original, fmt=None) self._test_wav(src, original, fmt=None)
# convert to float32
expected = _to_fltp(original) def test_audio_stream_format(self):
if self.test_type == "fileobj": "`format` argument properly changes the sample format of decoded audio"
src.seek(0) num_channels = 2
self._test_wav(src, expected, fmt="fltp") src, s32 = self.get_src(8000, dtype="int32", num_channels=num_channels)
args = {
"num_channels": num_channels,
"normalize": False,
"channels_first": False,
"num_frames": 1 << 16,
}
u8 = get_wav_data("uint8", **args)
s16 = get_wav_data("int16", **args)
s64 = s32.to(torch.int64) * (1 << 32)
f32 = get_wav_data("float32", **args)
f64 = get_wav_data("float64", **args)
s = StreamReader(src)
s.add_basic_audio_stream(frames_per_chunk=-1, format="u8")
s.add_basic_audio_stream(frames_per_chunk=-1, format="u8p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s16")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s16p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s32")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s32p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s64")
s.add_basic_audio_stream(frames_per_chunk=-1, format="s64p")
s.add_basic_audio_stream(frames_per_chunk=-1, format="flt")
s.add_basic_audio_stream(frames_per_chunk=-1, format="fltp")
s.add_basic_audio_stream(frames_per_chunk=-1, format="dbl")
s.add_basic_audio_stream(frames_per_chunk=-1, format="dblp")
s.process_all_packets()
chunks = s.pop_chunks()
self.assertEqual(chunks[0], u8, atol=1, rtol=0)
self.assertEqual(chunks[1], u8, atol=1, rtol=0)
self.assertEqual(chunks[2], s16)
self.assertEqual(chunks[3], s16)
self.assertEqual(chunks[4], s32)
self.assertEqual(chunks[5], s32)
self.assertEqual(chunks[6], s64)
self.assertEqual(chunks[7], s64)
self.assertEqual(chunks[8], f32)
self.assertEqual(chunks[9], f32)
self.assertEqual(chunks[10], f64)
self.assertEqual(chunks[11], f64)
@nested_params([4000, 16000])
def test_basic_audio_stream_sample_rate(self, sr):
"""`sample_rate` argument changes the sample_rate of decoded audio"""
src_num_channels, src_sr = 2, 8000
data = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False)
path = self.get_temp_path("ref.wav")
save_wav(path, data, src_sr, channels_first=False)
s = StreamReader(path)
s.add_basic_audio_stream(frames_per_chunk=-1, format="flt", sample_rate=sr)
self.assertEqual(s.get_src_stream_info(0).sample_rate, src_sr)
self.assertEqual(s.get_out_stream_info(0).sample_rate, sr)
s.process_all_packets()
(chunks,) = s.pop_chunks()
self.assertEqual(chunks.shape, [sr, src_num_channels])
@nested_params([1, 2, 3, 8, 16])
def test_basic_audio_stream_num_channels(self, num_channels):
"""`sample_rate` argument changes the number of channels of decoded audio"""
src_num_channels, sr = 2, 8000
data = get_sinusoid(sample_rate=sr, n_channels=src_num_channels, channels_first=False)
path = self.get_temp_path("ref.wav")
save_wav(path, data, sr, channels_first=False)
s = StreamReader(path)
s.add_basic_audio_stream(frames_per_chunk=-1, format="flt", num_channels=num_channels)
self.assertEqual(s.get_src_stream_info(0).num_channels, src_num_channels)
self.assertEqual(s.get_out_stream_info(0).num_channels, num_channels)
s.process_all_packets()
(chunks,) = s.pop_chunks()
self.assertEqual(chunks.shape, [sr, num_channels])
@nested_params( @nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"] ["int16", "uint8", "int32"], # "float", "double", "int64"]
...@@ -630,23 +1073,192 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase) ...@@ -630,23 +1073,192 @@ class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase)
rgb = torch.empty(1, 3, 256, 256, dtype=torch.uint8) rgb = torch.empty(1, 3, 256, 256, dtype=torch.uint8)
rgb[0, 0] = torch.arange(256, dtype=torch.uint8).reshape([1, -1]) rgb[0, 0] = torch.arange(256, dtype=torch.uint8).reshape([1, -1])
rgb[0, 1] = torch.arange(256, dtype=torch.uint8).reshape([-1, 1]) rgb[0, 1] = torch.arange(256, dtype=torch.uint8).reshape([-1, 1])
alpha = torch.full((1, 1, 256, 256), 255, dtype=torch.uint8)
for i in range(256): for i in range(256):
rgb[0, 2] = i rgb[0, 2] = i
path = self.get_temp_path(f"ref_{i}.png") path = self.get_temp_path(f"ref_{i}.png")
save_image(path, rgb[0], mode="RGB") save_image(path, rgb[0], mode="RGB")
rgb16 = ((rgb.to(torch.int32) - 128) << 8).to(torch.int16)
yuv = rgb_to_yuv_ccir(rgb) yuv = rgb_to_yuv_ccir(rgb)
yuv16 = yuv.to(torch.int16) * 4
bgr = rgb[:, [2, 1, 0], :, :] bgr = rgb[:, [2, 1, 0], :, :]
gray = rgb_to_gray(rgb) gray = rgb_to_gray(rgb)
argb = torch.cat([alpha, rgb], dim=1)
rgba = torch.cat([rgb, alpha], dim=1)
abgr = torch.cat([alpha, bgr], dim=1)
bgra = torch.cat([bgr, alpha], dim=1)
s = StreamReader(path) s = StreamReader(path)
s.add_basic_video_stream(frames_per_chunk=-1, format="yuv444p") s.add_basic_video_stream(frames_per_chunk=-1, format="yuv444p")
s.add_basic_video_stream(frames_per_chunk=-1, format="yuv420p")
s.add_basic_video_stream(frames_per_chunk=-1, format="nv12")
s.add_basic_video_stream(frames_per_chunk=-1, format="rgb24") s.add_basic_video_stream(frames_per_chunk=-1, format="rgb24")
s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24") s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
s.add_basic_video_stream(frames_per_chunk=-1, format="gray8") s.add_basic_video_stream(frames_per_chunk=-1, format="gray8")
s.add_basic_video_stream(frames_per_chunk=-1, format="rgb48le")
s.add_basic_video_stream(frames_per_chunk=-1, format="argb")
s.add_basic_video_stream(frames_per_chunk=-1, format="rgba")
s.add_basic_video_stream(frames_per_chunk=-1, format="abgr")
s.add_basic_video_stream(frames_per_chunk=-1, format="bgra")
s.add_basic_video_stream(frames_per_chunk=-1, format="yuv420p10le")
s.process_all_packets() s.process_all_packets()
output_yuv, output_rgb, output_bgr, output_gray = s.pop_chunks() chunks = s.pop_chunks()
self.assertEqual(yuv, output_yuv, atol=1, rtol=0) self.assertEqual(chunks[0], yuv, atol=1, rtol=0)
self.assertEqual(rgb, output_rgb, atol=0, rtol=0) self.assertEqual(chunks[1], yuv, atol=1, rtol=0)
self.assertEqual(bgr, output_bgr, atol=0, rtol=0) self.assertEqual(chunks[2], yuv, atol=1, rtol=0)
self.assertEqual(gray, output_gray, atol=1, rtol=0) self.assertEqual(chunks[3], rgb, atol=0, rtol=0)
self.assertEqual(chunks[4], bgr, atol=0, rtol=0)
self.assertEqual(chunks[5], gray, atol=1, rtol=0)
self.assertEqual(chunks[6], rgb16, atol=256, rtol=0)
self.assertEqual(chunks[7], argb, atol=0, rtol=0)
self.assertEqual(chunks[8], rgba, atol=0, rtol=0)
self.assertEqual(chunks[9], abgr, atol=0, rtol=0)
self.assertEqual(chunks[10], bgra, atol=0, rtol=0)
self.assertEqual(chunks[11], yuv16, atol=4, rtol=0)
@skipIfNoHWAccel("h264_cuvid")
class CuvidHWAccelInterfaceTest(TorchaudioTestCase):
def test_dup_hw_acel(self):
"""Specifying the same source stream with and without HW accel should fail (instead of segfault later)"""
src = get_asset_path("nasa_13013.mp4")
r = StreamReader(src)
r.add_video_stream(-1, decoder="h264_cuvid")
with self.assertRaises(RuntimeError):
r.add_video_stream(-1, decoder="h264_cuvid", hw_accel="cuda")
r = StreamReader(src)
r.add_video_stream(-1, decoder="h264_cuvid", hw_accel="cuda")
with self.assertRaises(RuntimeError):
r.add_video_stream(-1, decoder="h264_cuvid")
@_media_source
class CudaDecoderTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
def _test_decode(
self,
decoder: str,
src_path: str,
height: int,
width: int,
ref_num_frames: int,
hw_accel=None,
decoder_option=None,
dtype: torch.dtype = torch.uint8,
):
src = self.get_src(get_asset_path(src_path))
r = StreamReader(src)
r.add_video_stream(10, decoder=decoder, decoder_option=decoder_option, hw_accel=hw_accel)
num_frames = 0
for (chunk,) in r.stream():
self.assertEqual(chunk.device, torch.device(hw_accel or "cpu"))
self.assertEqual(chunk.dtype, dtype)
self.assertEqual(chunk.shape, torch.Size([10, 3, height, width]))
num_frames += chunk.size(0)
assert num_frames == ref_num_frames
@skipIfNoHWAccel("h264_cuvid")
def test_h264_cuvid(self):
"""GPU decoder works for H264"""
self._test_decode("h264_cuvid", "nasa_13013.mp4", 270, 480, 390)
@skipIfNoHWAccel("h264_cuvid")
def test_h264_cuvid_hw_accel(self):
"""GPU decoder works for H264 with HW acceleration, and put the frames on CUDA tensor"""
self._test_decode("h264_cuvid", "nasa_13013.mp4", 270, 480, 390, hw_accel="cuda:0")
@skipIfNoHWAccel("h264_cuvid")
def test_h264_cuvid_hw_accel_resize(self):
"""GPU decoder works for H264 with HW acceleration and resize option"""
w, h = 240, 136
self._test_decode(
"h264_cuvid", "nasa_13013.mp4", h, w, 390, hw_accel="cuda:0", decoder_option={"resize": f"{w}x{h}"}
)
@skipIfNoHWAccel("h264_cuvid")
def test_h264_cuvid_hw_accel_crop(self):
"""GPU decoder works for H264 with HW acceleration and crop option"""
top, bottom, left, right = 3, 5, 7, 9
self._test_decode(
"h264_cuvid",
"nasa_13013.mp4",
262,
464,
390,
hw_accel="cuda:0",
decoder_option={"crop": f"{top}x{bottom}x{left}x{right}"},
)
@skipIfNoHWAccel("hevc_cuvid")
def test_hevc_cuvid(self):
"""GPU decoder works for H265/HEVC"""
self._test_decode("hevc_cuvid", "testsrc.hevc", 144, 256, 300)
@skipIfNoHWAccel("hevc_cuvid")
def test_hevc_cuvid_hw_accel(self):
"""GPU decoder works for H265/HEVC with HW acceleration, and put the frames on CUDA tensor"""
self._test_decode("hevc_cuvid", "testsrc.hevc", 144, 256, 300, hw_accel="cuda:0", dtype=torch.int16)
@skipIfNoHWAccel("hevc_cuvid")
def test_hevc_cuvid_hw_accel_resize(self):
"""GPU decoder works for H265/HEVC with HW acceleration and resize option"""
w, h = 128, 64
self._test_decode(
"hevc_cuvid",
"testsrc.hevc",
h,
w,
300,
hw_accel="cuda:0",
dtype=torch.int16,
decoder_option={"resize": f"{w}x{h}"},
)
@skipIfNoHWAccel("hevc_cuvid")
def test_hevc_cuvid_hw_accel_crop(self):
"""GPU decoder works for H265/HEVC with HW acceleration and crop option"""
top, bottom, left, right = 3, 5, 7, 9
self._test_decode(
"hevc_cuvid",
"testsrc.hevc",
136,
240,
300,
hw_accel="cuda:0",
dtype=torch.int16,
decoder_option={"crop": f"{top}x{bottom}x{left}x{right}"},
)
@skipIfNoHWAccel("h264_cuvid")
# Disabled in CI: https://github.com/pytorch/audio/issues/3376
@disabledInCI
class FilterGraphWithCudaAccel(TorchaudioTestCase):
def test_sclae_cuda_change_size(self):
"""scale_cuda filter can be used when HW accel is on"""
src = get_asset_path("nasa_13013.mp4")
r = StreamReader(src)
r.add_video_stream(10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=iw/2:ih/2")
num_frames = 0
for (chunk,) in r.stream():
self.assertEqual(chunk.device, torch.device("cuda:0"))
self.assertEqual(chunk.dtype, torch.uint8)
self.assertEqual(chunk.shape, torch.Size([10, 3, 135, 240]))
num_frames += chunk.size(0)
assert num_frames == 390
def test_scale_cuda_format(self):
"""yuv444p format conversion should work"""
src = get_asset_path("nasa_13013.mp4")
r = StreamReader(src)
r.add_video_stream(10, decoder="h264_cuvid", hw_accel="cuda", filter_desc="scale_cuda=format=yuv444p")
num_frames = 0
for (chunk,) in r.stream():
self.assertEqual(chunk.device, torch.device("cuda:0"))
self.assertEqual(chunk.dtype, torch.uint8)
self.assertEqual(chunk.shape, torch.Size([10, 3, 270, 480]))
num_frames += chunk.size(0)
assert num_frames == 390
import io
import math
import torch import torch
import torchaudio import torchaudio
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
get_sinusoid,
is_ffmpeg_available, is_ffmpeg_available,
nested_params, nested_params,
rgb_to_yuv_ccir, rgb_to_yuv_ccir,
...@@ -13,8 +17,10 @@ from torchaudio_unittest.common_utils import ( ...@@ -13,8 +17,10 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase, TorchaudioTestCase,
) )
from .common import lt42
if is_ffmpeg_available(): if is_ffmpeg_available():
from torchaudio.io import StreamReader, StreamWriter from torchaudio.io import CodecConfig, StreamReader, StreamWriter
def get_audio_chunk(fmt, sample_rate, num_channels): def get_audio_chunk(fmt, sample_rate, num_channels):
...@@ -87,9 +93,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -87,9 +93,21 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
def get_dst(self, path): def get_dst(self, path):
return super().get_dst(self.get_temp_path(path)) return super().get_dst(self.get_temp_path(path))
def get_buf(self, path): def test_unopened_error(self):
with open(self.get_temp_path(path), "rb") as fileobj: """If dst is not opened when attempting to write data, runtime error should be raised"""
return fileobj.read() path = self.get_dst("test.mp4")
s = StreamWriter(path, format="mp4")
s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
s.add_audio_stream(sample_rate=16000, num_channels=2)
s.add_video_stream(frame_rate=30, width=16, height=16)
dummy = torch.zeros((3, 2))
with self.assertRaises(RuntimeError):
s.write_audio_chunk(0, dummy)
dummy = torch.zeros((3, 3, 16, 16))
with self.assertRaises(RuntimeError):
s.write_video_chunk(1, dummy)
@skipIfNoModule("tinytag") @skipIfNoModule("tinytag")
def test_metadata_overwrite(self): def test_metadata_overwrite(self):
...@@ -135,21 +153,26 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -135,21 +153,26 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
@parameterized.expand( @parameterized.expand(
[ [
("mp3", 8000, 1, "s32p", None), ("mp3", 8000, 1, None, "s32p", None),
("mp3", 16000, 2, "fltp", None), ("mp3", 16000, 2, None, "fltp", None),
("mp3", 44100, 1, "s16p", {"abr": "true"}), ("mp3", 44100, 1, None, "s16p", {"abr": "true"}),
("flac", 8000, 1, "s16", None), ("flac", 8000, 1, None, "s16", None),
("flac", 16000, 2, "s32", None), ("flac", 16000, 2, None, "s32", None),
("opus", 48000, 2, None, {"strict": "experimental"}), ("opus", 48000, 2, "opus", None, None),
("adts", 8000, 1, "fltp", None), # AAC format ("ogg", 48000, 2, "vorbis", None, None),
("adts", 8000, 1, None, "fltp", None), # AAC format
] ]
) )
def test_valid_audio_muxer_and_codecs(self, ext, sample_rate, num_channels, encoder_format, encoder_option): def test_valid_audio_muxer_and_codecs(
self, ext, sample_rate, num_channels, encoder, encoder_format, encoder_option
):
"""Tensor of various dtypes can be saved as given format.""" """Tensor of various dtypes can be saved as given format."""
path = self.get_dst(f"test.{ext}") path = self.get_dst(f"test.{ext}")
s = StreamWriter(path, format=ext) s = StreamWriter(path, format=ext)
s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()}) s.set_metadata(metadata={"artist": "torchaudio", "title": self.id()})
s.add_audio_stream(sample_rate, num_channels, encoder_option=encoder_option, encoder_format=encoder_format) s.add_audio_stream(
sample_rate, num_channels, encoder=encoder, encoder_option=encoder_option, encoder_format=encoder_format
)
chunk = get_audio_chunk("flt", sample_rate, num_channels) chunk = get_audio_chunk("flt", sample_rate, num_channels)
with s.open(): with s.open():
...@@ -202,6 +225,19 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -202,6 +225,19 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s.write_audio_chunk(0, audio) s.write_audio_chunk(0, audio)
s.write_video_chunk(1, video) s.write_video_chunk(1, video)
@skipIfNoFFmpeg
class StreamWriterCorrectnessTest(TempDirMixin, TorchaudioTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
torchaudio.utils.ffmpeg_utils.set_log_level(32)
@classmethod
def tearDownClass(cls):
torchaudio.utils.ffmpeg_utils.set_log_level(8)
super().tearDownClass()
@nested_params( @nested_params(
[ [
("gray8", "gray8"), ("gray8", "gray8"),
...@@ -227,16 +263,16 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -227,16 +263,16 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
chunk = torch.randint(low=0, high=255, size=src_size, dtype=torch.uint8) chunk = torch.randint(low=0, high=255, size=src_size, dtype=torch.uint8)
# Write data # Write data
dst = self.get_dst(filename) dst = self.get_temp_path(filename)
s = StreamWriter(dst, format="rawvideo") s = StreamWriter(dst, format="rawvideo")
s.add_video_stream(frame_rate, width, height, format=src_fmt, encoder_format=encoder_fmt) s.add_video_stream(frame_rate, width, height, format=src_fmt, encoder_format=encoder_fmt)
with s.open(): with s.open():
s.write_video_chunk(0, chunk) s.write_video_chunk(0, chunk)
# Fetch the written data # Fetch the written data
if self.test_fileobj: with open(dst, "rb") as fileobj:
dst.flush() buf = fileobj.read()
buf = self.get_buf(filename)
result = torch.frombuffer(buf, dtype=torch.uint8) result = torch.frombuffer(buf, dtype=torch.uint8)
if encoder_fmt.endswith("p"): if encoder_fmt.endswith("p"):
result = result.reshape(src_size) result = result.reshape(src_size)
...@@ -261,14 +297,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -261,14 +297,12 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
h, w = resolution h, w = resolution
# Write data # Write data
dst = self.get_dst(filename) dst = self.get_temp_path(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext) s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_video_stream(frame_rate=framerate, height=h, width=w, format=format) s.add_video_stream(frame_rate=framerate, height=h, width=w, format=format)
chunk = torch.stack([torch.full((3, h, w), i, dtype=torch.uint8) for i in torch.linspace(0, 255, 256)]) chunk = torch.stack([torch.full((3, h, w), i, dtype=torch.uint8) for i in torch.linspace(0, 255, 256)])
with s.open(): with s.open():
s.write_video_chunk(0, chunk) s.write_video_chunk(0, chunk)
if self.test_fileobj:
dst.flush()
# Load data # Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename)) s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
...@@ -293,30 +327,54 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -293,30 +327,54 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
pass pass
@nested_params( @nested_params(
["wav", "mp3", "flac"], ["wav", "flac"],
[8000, 16000, 44100], [8000, 16000, 44100],
[1, 2], [1, 2],
) )
def test_audio_num_frames(self, ext, sample_rate, num_channels): def test_audio_num_frames_lossless(self, ext, sample_rate, num_channels):
"""""" """Lossless format preserves the data"""
filename = f"test.{ext}" filename = f"test.{ext}"
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False)
# Write data # Write data
dst = self.get_dst(filename) dst = self.get_temp_path(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext) s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels) s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16")
with s.open():
s.write_audio_chunk(0, data)
freq = 300 # Load data
duration = 60 s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration) s.add_audio_stream(-1)
if num_channels == 1: s.process_all_packets()
chunk = torch.sin(theta).unsqueeze(-1) (saved,) = s.pop_chunks()
else:
chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1) self.assertEqual(saved, data)
@parameterized.expand(
[
("mp3", 1, 8000),
("mp3", 1, 16000),
("mp3", 1, 44100),
("mp3", 2, 8000),
("mp3", 2, 16000),
("mp3", 2, 44100),
("opus", 1, 48000),
]
)
def test_audio_num_frames_lossy(self, ext, num_channels, sample_rate):
"""Saving audio preserves the number of channels and frames"""
filename = f"test.{ext}"
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
# Write data
dst = self.get_temp_path(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
with s.open(): with s.open():
s.write_audio_chunk(0, chunk) s.write_audio_chunk(0, data)
if self.test_fileobj:
dst.flush()
# Load data # Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename)) s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
...@@ -324,9 +382,28 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -324,9 +382,28 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
s.process_all_packets() s.process_all_packets()
(saved,) = s.pop_chunks() (saved,) = s.pop_chunks()
assert saved.shape == chunk.shape # On 4.1 OPUS produces 48312 samples (extra 312)
if format in ["wav", "flac"]: # this has been fixed on 4.2+
self.assertEqual(saved, chunk) # TODO: issue warning if on 4.1?
if ext == "opus" and lt42():
return
self.assertEqual(saved.shape, data.shape)
def test_g722_sample_rate(self):
"""Encoding G.722 properly converts sample rate to 16k"""
filename = "test.g722"
sample_rate = 41000
data = get_sinusoid(sample_rate=sample_rate, n_channels=1, channels_first=False)
# write data
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format="g722")
w.add_audio_stream(sample_rate=sample_rate, num_channels=1)
with w.open():
w.write_audio_chunk(0, data)
r = StreamReader(src=self.get_temp_path(filename))
self.assertEqual(r.get_src_stream_info(0).sample_rate, 16000)
def test_preserve_fps(self): def test_preserve_fps(self):
"""Decimal point frame rate is properly saved """Decimal point frame rate is properly saved
...@@ -339,16 +416,346 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -339,16 +416,346 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
width, height = 96, 128 width, height = 96, 128
# Write data # Write data
dst = self.get_dst(filename) dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext) writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height) writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
video = torch.randint(256, (90, 3, height, width), dtype=torch.uint8) video = torch.randint(256, (90, 3, height, width), dtype=torch.uint8)
with writer.open(): with writer.open():
writer.write_video_chunk(0, video) writer.write_video_chunk(0, video)
if self.test_fileobj:
dst.flush()
# Load data # Load data
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename)) reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
assert reader.get_src_stream_info(0).frame_rate == frame_rate assert reader.get_src_stream_info(0).frame_rate == frame_rate
def test_video_pts_increment(self):
"""PTS values increment by the inverse of frame rate"""
ext = "mp4"
num_frames = 256
filename = f"test.{ext}"
frame_rate = 5000 / 167
width, height = 96, 128
# Write data
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
video = torch.randint(256, (num_frames, 3, height, width), dtype=torch.uint8)
with writer.open():
writer.write_video_chunk(0, video)
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(1)
pts = [chunk.pts for (chunk,) in reader.stream()]
assert len(pts) == num_frames
for i, val in enumerate(pts):
expected = i / frame_rate
assert abs(val - expected) < 1e-10
def test_audio_pts_increment(self):
"""PTS values increment by the inverse of sample rate"""
ext = "wav"
filename = f"test.{ext}"
sample_rate = 8000
num_channels = 2
# Write data
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
audio = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
num_frames = audio.size(0)
with writer.open():
writer.write_audio_chunk(0, audio)
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
frames_per_chunk = sample_rate // 4
reader.add_audio_stream(frames_per_chunk, -1)
chunks = [chunk for (chunk,) in reader.stream()]
expected = num_frames // frames_per_chunk
assert len(chunks) == expected, f"Expected {expected} elements. Found {len(chunks)}"
num_samples = 0
for chunk in chunks:
expected = num_samples / sample_rate
num_samples += chunk.size(0)
assert abs(chunk.pts - expected) < 1e-10
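# A minimal sketch of the expectation above: with sample_rate=8000 and
# frames_per_chunk=2000, chunk k starts at sample k * 2000, so its PTS should
# be k * 2000 / 8000 = k * 0.25 seconds.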
@parameterized.expand(
[
(10, 100),
(15, 150),
(24, 240),
(25, 200),
(30, 300),
(50, 500),
(60, 600),
# PTS value conversion involves float <-> int conversion, which can
# introduce rounding error.
# This test is a spot-check for the popular 29.97 Hz rate (see the illustration after this test).
(30000 / 1001, 10010),
]
)
def test_video_pts_overwrite(self, frame_rate, num_frames):
"""Can overwrite PTS"""
ext = "mp4"
filename = f"test.{ext}"
width, height = 8, 8
# Write data
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
writer.add_video_stream(frame_rate=frame_rate, width=width, height=height)
video = torch.zeros((1, 3, height, width), dtype=torch.uint8)
reference_pts = []
with writer.open():
for i in range(num_frames):
pts = i / frame_rate
reference_pts.append(pts)
writer.write_video_chunk(0, video, pts)
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(1)
pts = [chunk.pts for (chunk,) in reader.stream()]
assert len(pts) == len(reference_pts)
for val, ref in zip(pts, reference_pts):
# torch provides isclose, but we don't know if converting floats to tensor
# could introduce a discrepancy, so we compare floats and use math.isclose
# for that.
assert math.isclose(val, ref)
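# Illustration of the rounding concern for 29.97 Hz (a sketch; the actual time
# base is chosen by the muxer). MP4 stores each PTS as an integer count of
# time-base ticks, so a float PTS is rounded on write and re-derived on read:
#
#   frame_rate = 30000 / 1001            # ~29.97 Hz
#   ticks = round(90000 / frame_rate)    # 3003, assuming a 1/90000 time base
#   recovered = ticks / 90000            # 0.0333666..., equals 1 / frame_rate here
#
# math.isclose above absorbs the cases where this round trip is not exact.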
def test_codec_config(self):
"""Can successfully set configuration and write audio."""
ext = "mp3"
filename = f"test.{ext}"
sample_rate = 44100
num_channels = 2
# Write data
dst = self.get_temp_path(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
codec_config = CodecConfig(bit_rate=198_000, compression_level=3)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, codec_config=codec_config)
audio = torch.zeros((8000, 2))
with writer.open():
writer.write_audio_chunk(0, audio)
def test_codec_config_bit_rate_output(self):
"""Increasing the specified bit rate yields a larger encoded output."""
ext = "mp3"
sample_rate = 44100
num_channels = 2
audio = torch.rand((8000, num_channels))
def write_audio(buffer, bit_rate):
writer = torchaudio.io.StreamWriter(dst=buffer, format=ext)
writer.add_audio_stream(
sample_rate=sample_rate,
num_channels=num_channels,
codec_config=CodecConfig(bit_rate=bit_rate),
)
with writer.open():
writer.write_audio_chunk(0, audio)
dst = io.BytesIO()
write_audio(dst, 198_000)
out0_size = dst.tell()
dst = io.BytesIO()
write_audio(dst, 320_000)
out1_size = dst.tell()
self.assertGreater(out1_size, out0_size)
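# Rough sanity check for the comparison above: 8000 samples at 44.1 kHz is
# about 0.18 s of audio, so the 320 kbps output should be roughly 320/198 times
# the size of the 198 kbps one (container overhead makes the ratio approximate).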
def test_filter_graph_audio(self):
"""Can apply additional effect with filter graph"""
sample_rate = 8000
num_channels = 2
ext = "wav"
filename = f"test.{ext}"
original = get_audio_chunk("s16", num_channels=num_channels, sample_rate=sample_rate)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(sample_rate=8000, num_channels=num_channels, filter_desc="areverse", format="s16")
with w.open():
w.write_audio_chunk(0, original)
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_audio_stream(-1)
reader.process_all_packets()
(output,) = reader.pop_chunks()
self.assertEqual(output, original.flip(0))
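# The filter_desc string is handed to FFmpeg's filter graph, so any audio filter
# chain should work in place of "areverse", e.g. "areverse,atempo=2.0"
# (an assumption: availability of a given filter depends on the FFmpeg build).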
def test_filter_graph_video(self):
"""Can apply additional effect with filter graph"""
src_rate = 30
num_frames, width, height = 400, 160, 90
filter_desc = "framestep=2"
enc_rate = 15
ext = "mp4"
filename = f"test.{ext}"
original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(
frame_rate=src_rate,
format="rgb24",
height=height,
width=width,
filter_desc=filter_desc,
encoder_format="yuv420p",
encoder_frame_rate=enc_rate,
)
with w.open():
w.write_video_chunk(0, original)
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(-1)
reader.process_all_packets()
(output,) = reader.pop_chunks()
self.assertEqual(output.shape, [num_frames // 2, 3, height, width])
@parameterized.expand(
[
("wav", "pcm_s16le", 8000, 16000, 1, 2),
("wav", "pcm_s16le", 8000, 16000, 2, 1),
("wav", "pcm_s16le", 8000, 16000, 2, 4),
("wav", "pcm_s16le", 16000, 8000, 1, 2),
("wav", "pcm_s16le", 16000, 8000, 2, 1),
("wav", "pcm_s16le", 16000, 8000, 2, 4),
("wav", "pcm_f32le", 8000, 16000, 1, 2),
("wav", "pcm_f32le", 8000, 16000, 2, 1),
("wav", "pcm_f32le", 8000, 16000, 2, 4),
("wav", "pcm_f32le", 16000, 8000, 1, 2),
("wav", "pcm_f32le", 16000, 8000, 2, 1),
("wav", "pcm_f32le", 16000, 8000, 2, 4),
("ogg", "opus", 8000, 48000, 1, 2),
("ogg", "opus", 8000, 48000, 2, 1),
("ogg", "flac", 8000, 41000, 1, 2),
("ogg", "flac", 8000, 41000, 2, 1),
("ogg", "vorbis", 16000, 8000, 1, 2),
("ogg", "vorbis", 16000, 8000, 4, 2),
]
)
def test_change_audio_encoder_spec(self, ext, encoder, src_sr, enc_sr, src_num_channels, enc_num_channels):
"""Can change sample rate and channels on-the-fly"""
filename = f"test.{ext}"
original = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False, duration=0.1)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(
sample_rate=src_sr,
format="flt",
num_channels=src_num_channels,
encoder=encoder,
encoder_sample_rate=enc_sr,
encoder_num_channels=enc_num_channels,
)
with w.open():
w.write_audio_chunk(0, original)
# check
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
i = reader.get_src_stream_info(0)
self.assertEqual(i.sample_rate, enc_sr)
self.assertEqual(i.num_channels, enc_num_channels)
@parameterized.expand(
[
# opus only supports 48kHz
("ogg", "opus", 8000, 48000, 1, 1),
("ogg", "opus", 16000, 48000, 2, 2),
# vorbis only supports 2 channels
("ogg", "vorbis", 16000, 16000, 1, 2),
("ogg", "vorbis", 16000, 16000, 2, 2),
("ogg", "vorbis", 16000, 16000, 4, 2),
]
)
def test_change_encoder_spec_default(
self, ext, encoder, src_sr, expected_sr, src_num_channels, expected_num_channels
):
"""If input rate/channels are not supported, encoder picks supported one automatically."""
filename = f"test.{ext}"
original = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False, duration=0.1)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(
sample_rate=src_sr,
format="flt",
num_channels=src_num_channels,
encoder=encoder,
)
with w.open():
w.write_audio_chunk(0, original)
# check
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
i = reader.get_src_stream_info(0)
self.assertEqual(i.sample_rate, expected_sr)
self.assertEqual(i.num_channels, expected_num_channels)
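# Summary of the fallbacks exercised above, per the comments in the test cases:
# the opus encoder only accepts 48 kHz input and the vorbis encoder only
# supports two channels, so StreamWriter substitutes those defaults when the
# requested spec is unsupported.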
@parameterized.expand(
[
("mp4", None, 10, 30, (100, 160), (200, 320)),
("mp4", None, 10, 30, (100, 160), (50, 80)),
("mp4", None, 30, 10, (100, 160), (200, 320)),
("mp4", None, 30, 10, (100, 160), (50, 80)),
]
)
def test_change_video_encoder_spec(self, ext, encoder, src_rate, enc_rate, src_size, enc_size):
"""Can change the frame rate and image size on-the-fly"""
width, height = src_size
enc_width, enc_height = enc_size
ext = "mp4"
filename = f"test.{ext}"
num_frames = 256
original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(
frame_rate=src_rate,
format="rgb24",
height=height,
width=width,
encoder_format="yuv420p",
encoder_frame_rate=enc_rate,
encoder_width=enc_width,
encoder_height=enc_height,
)
with w.open():
w.write_video_chunk(0, original)
# check
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
i = reader.get_src_stream_info(0)
self.assertEqual(i.frame_rate, enc_rate)
self.assertEqual(i.width, enc_width)
self.assertEqual(i.height, enc_height)
...@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase): ...@@ -169,3 +169,19 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
expected_tokens = ["|", "f", "|", "o", "a"] expected_tokens = ["|", "f", "|", "o", "a"]
self.assertEqual(tokens, expected_tokens) self.assertEqual(tokens, expected_tokens)
def test_lm_lifecycle(self):
"""Passing lm without assiging it to a vaiable won't cause runtime error
https://github.com/pytorch/audio/issues/3218
"""
from torchaudio.models.decoder import ctc_decoder
from .ctc_decoder_utils import CustomZeroLM
decoder = ctc_decoder(
lexicon=get_asset_path("decoder/lexicon.txt"),
tokens=get_asset_path("decoder/tokens.txt"),
lm=CustomZeroLM(),
)
decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32))
import torch
from torchaudio_unittest.common_utils import (
get_asset_path,
skipIfNoCuCtcDecoder,
skipIfNoCuda,
TempDirMixin,
TorchaudioTestCase,
)
NUM_TOKENS = 7
@skipIfNoCuda
@skipIfNoCuCtcDecoder
class CUCTCDecoderTest(TempDirMixin, TorchaudioTestCase):
def _get_decoder(self, tokens=None, **kwargs):
from torchaudio.models.decoder import cuda_ctc_decoder
if tokens is None:
tokens = get_asset_path("decoder/tokens.txt")
return cuda_ctc_decoder(
tokens=tokens,
beam_size=5,
**kwargs,
)
def _get_emissions(self):
B, T, N = 4, 15, NUM_TOKENS
emissions = torch.rand(B, T, N).cuda()
emissions = torch.nn.functional.log_softmax(emissions, -1)
return emissions
def test_construct_basic_decoder_path(self):
tokens_path = get_asset_path("decoder/tokens.txt")
self._get_decoder(tokens=tokens_path)
def test_construct_basic_decoder_tokens(self):
tokens = ["-", "|", "f", "o", "b", "a", "r"]
self._get_decoder(tokens=tokens)
def test_shape(self):
log_probs = self._get_emissions()
encoder_out_lens = torch.tensor([15, 14, 13, 12], dtype=torch.int32).cuda()
decoder = self._get_decoder()
results = decoder(log_probs, encoder_out_lens)
self.assertEqual(len(results), log_probs.shape[0])
...@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin): ...@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
self.assertEqual(res, scripted_res) self.assertEqual(res, scripted_res)
state = res[1] state = res[1]
hypo = res[0][0] hypo = res[0]
scripted_state = scripted_res[1] scripted_state = scripted_res[1]
scripted_hypo = scripted_res[0][0] scripted_hypo = scripted_res[0]
import torch
from parameterized import parameterized
from torchaudio.models import squim_objective_base, squim_subjective_base
from torchaudio_unittest.common_utils import skipIfNoCuda, torch_script, TorchaudioTestCase
class TestSquimObjective(TorchaudioTestCase):
def _smoke_test_objective(self, model, device, dtype):
model = model.to(device=device, dtype=dtype)
model = model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
model(waveforms)
@parameterized.expand([(torch.float32,), (torch.float64,)])
def test_cpu_smoke_test(self, dtype):
model = squim_objective_base()
self._smoke_test_objective(model, torch.device("cpu"), dtype)
@parameterized.expand([(torch.float32,), (torch.float64,)])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
model = squim_objective_base()
self._smoke_test_objective(model, torch.device("cuda"), dtype)
def test_batch_consistency(self):
model = squim_objective_base()
model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames)
ref_scores = model(waveforms)
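# Assumption behind the bookkeeping below: squim_objective_base returns three
# per-batch score tensors (the objective metrics STOI, PESQ, and SI-SDR), hence
# the three accumulators and range(3).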
hyp_scores = [torch.zeros(batch_size), torch.zeros(batch_size), torch.zeros(batch_size)]
for i in range(batch_size):
scores = model(waveforms[i : i + 1])
for j in range(3):
hyp_scores[j][i] = scores[j]
self.assertEqual(len(hyp_scores), len(ref_scores))
for i in range(len(ref_scores)):
self.assertEqual(hyp_scores[i], ref_scores[i])
def test_torchscript_consistency(self):
model = squim_objective_base()
model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames)
ref_scores = model(waveforms)
scripted = torch_script(model)
hyp_scores = scripted(waveforms)
self.assertEqual(len(hyp_scores), len(ref_scores))
for i in range(len(ref_scores)):
self.assertEqual(hyp_scores[i], ref_scores[i])
class TestSquimSubjective(TorchaudioTestCase):
def _smoke_test_subjective(self, model, device, dtype):
model = model.to(device=device, dtype=dtype)
model = model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
reference = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
model(waveforms, reference)
@parameterized.expand([(torch.float32,), (torch.float64,)])
def test_cpu_smoke_test(self, dtype):
model = squim_subjective_base()
self._smoke_test_subjective(model, torch.device("cpu"), dtype)
@parameterized.expand([(torch.float32,), (torch.float64,)])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
model = squim_subjective_base()
self._smoke_test_subjective(model, torch.device("cuda"), dtype)
def test_batch_consistency(self):
model = squim_subjective_base()
model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames)
reference = torch.randn(batch_size, num_frames)
ref_scores = model(waveforms, reference)
hyp_scores = []
for i in range(batch_size):
scores = model(waveforms[i : i + 1], reference[i : i + 1])
hyp_scores.append(scores)
hyp_scores = torch.tensor(hyp_scores)
self.assertEqual(hyp_scores, ref_scores)
def test_torchscript_consistency(self):
model = squim_subjective_base()
model.eval()
batch_size, num_frames = 3, 16000
waveforms = torch.randn(batch_size, num_frames)
reference = torch.randn(batch_size, num_frames)
ref_scores = model(waveforms, reference)
scripted = torch_script(model)
hyp_scores = scripted(waveforms, reference)
self.assertEqual(hyp_scores, ref_scores)
...@@ -42,7 +42,7 @@ class TorchscriptConsistencyMixin(TestBaseMixin): ...@@ -42,7 +42,7 @@ class TorchscriptConsistencyMixin(TestBaseMixin):
class Tacotron2EncoderTests(TorchscriptConsistencyMixin): class Tacotron2EncoderTests(TorchscriptConsistencyMixin):
@skipIfPy310 # @skipIfPy310
def test_tacotron2_torchscript_consistency(self): def test_tacotron2_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Encoder.""" r"""Validate the torchscript consistency of a Encoder."""
n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
...@@ -266,7 +266,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -266,7 +266,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
(16,), (16,),
] ]
) )
@skipIfPy310 # @skipIfPy310
def test_tacotron2_torchscript_consistency(self, n_batch): def test_tacotron2_torchscript_consistency(self, n_batch):
r"""Validate the torchscript consistency of a Tacotron2.""" r"""Validate the torchscript consistency of a Tacotron2."""
n_mels = 80 n_mels = 80
...@@ -335,7 +335,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin): ...@@ -335,7 +335,7 @@ class Tacotron2Tests(TorchscriptConsistencyMixin):
(16,), (16,),
] ]
) )
@skipIfPy310 # @skipIfPy310
def test_tacotron2_inference_torchscript_consistency(self, n_batch): def test_tacotron2_inference_torchscript_consistency(self, n_batch):
r"""Validate the torchscript consistency of Tacotron2 inference function.""" r"""Validate the torchscript consistency of Tacotron2 inference function."""
n_mels = 40 n_mels = 40
......
...@@ -9,9 +9,12 @@ from torchaudio.models.wav2vec2 import ( ...@@ -9,9 +9,12 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
wav2vec2_xlsr_1b,
wav2vec2_xlsr_2b,
wav2vec2_xlsr_300m,
) )
from torchaudio.models.wav2vec2.utils import import_fairseq_model from torchaudio.models.wav2vec2.utils import import_fairseq_model
from torchaudio_unittest.common_utils import get_asset_path, skipIfNoModule, TorchaudioTestCase from torchaudio_unittest.common_utils import get_asset_path, skipIfCudaSmallMemory, skipIfNoModule, TorchaudioTestCase
def _load_config(*paths): def _load_config(*paths):
...@@ -31,6 +34,9 @@ WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k") ...@@ -31,6 +34,9 @@ WAV2VEC2_XLSR_53_56K = _load_config("xlsr_53_56k")
HUBERT_BASE = _load_config("hubert_base_ls960") HUBERT_BASE = _load_config("hubert_base_ls960")
HUBERT_LARGE_LL60K = _load_config("hubert_large_ll60k") HUBERT_LARGE_LL60K = _load_config("hubert_large_ll60k")
HUBERT_XLARGE_LL60K = _load_config("hubert_xtralarge_ll60k") HUBERT_XLARGE_LL60K = _load_config("hubert_xtralarge_ll60k")
WAV2VEC2_XLSR_300M = _load_config("xlsr_300m")
WAV2VEC2_XLSR_1B = _load_config("xlsr_1b")
WAV2VEC2_XLSR_2B = _load_config("xlsr_2b")
# Finetuning models # Finetuning models
WAV2VEC2_BASE_960H = _load_config("wav2vec_small_960h") WAV2VEC2_BASE_960H = _load_config("wav2vec_small_960h")
WAV2VEC2_LARGE_960H = _load_config("wav2vec_large_960h") WAV2VEC2_LARGE_960H = _load_config("wav2vec_large_960h")
...@@ -50,6 +56,14 @@ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand( ...@@ -50,6 +56,14 @@ WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand(
], ],
name_func=_name_func, name_func=_name_func,
) )
XLSR_PRETRAINING_CONFIGS = parameterized.expand(
[
(WAV2VEC2_XLSR_300M, wav2vec2_xlsr_300m),
(WAV2VEC2_XLSR_1B, wav2vec2_xlsr_1b),
(WAV2VEC2_XLSR_2B, wav2vec2_xlsr_2b),
],
name_func=_name_func,
)
HUBERT_PRETRAINING_CONFIGS = parameterized.expand( HUBERT_PRETRAINING_CONFIGS = parameterized.expand(
[ [
(HUBERT_BASE, hubert_base), (HUBERT_BASE, hubert_base),
...@@ -134,7 +148,24 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -134,7 +148,24 @@ class TestFairseqIntegration(TorchaudioTestCase):
hyp, _ = imported.extract_features(x) hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1) refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
for i, (ref, _) in enumerate(refs["layer_results"]): for i, (ref, _) in enumerate(refs["layer_results"]):
self.assertEqual(hyp[i], ref.transpose(0, 1)) self.assertEqual(hyp[i], ref.transpose(0, 1), atol=1.5e-5, rtol=1.3e-6)
@XLSR_PRETRAINING_CONFIGS
@skipIfCudaSmallMemory
def test_import_xlsr_pretraining_model(self, config, factory_func):
"""XLS-R pretraining models from fairseq can be imported and yields the same results"""
batch_size, num_frames = 3, 1024
original = self._get_model(config).eval()
imported = import_fairseq_model(original).eval()
x = torch.randn(batch_size, num_frames)
hyp, _ = imported.extract_features(x)
refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1)
for i, (ref, _) in enumerate(refs["layer_results"]):
# There is one element whose difference is over 1e-5 in wav2vec2_xlsr_1b and wav2vec2_xlsr_2b.
atol = 1.0e-05 if factory_func is wav2vec2_xlsr_300m else 1e-4
self.assertEqual(hyp[i], ref.transpose(0, 1), atol=atol, rtol=1.3e-6)
@HUBERT_PRETRAINING_CONFIGS @HUBERT_PRETRAINING_CONFIGS
def test_import_hubert_pretraining_model(self, config, factory_func): def test_import_hubert_pretraining_model(self, config, factory_func):
...@@ -150,15 +181,13 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -150,15 +181,13 @@ class TestFairseqIntegration(TorchaudioTestCase):
# check the last layer # check the last layer
ref, _ = original.extract_features(x, padding_mask=mask, output_layer=len(original.encoder.layers)) ref, _ = original.extract_features(x, padding_mask=mask, output_layer=len(original.encoder.layers))
atol = 3.0e-05 if factory_func is hubert_xlarge else 1.0e-5 self.assertEqual(hyp[-1], ref, atol=3.0e-5, rtol=1.3e-6)
self.assertEqual(hyp[-1], ref, atol=atol, rtol=1.3e-6)
# check the first layer # check the first layer
ref, _ = original.extract_features(x, padding_mask=mask, output_layer=1) ref, _ = original.extract_features(x, padding_mask=mask, output_layer=1)
self.assertEqual(hyp[0], ref) self.assertEqual(hyp[0], ref)
@ALL_PRETRAINING_CONFIGS def _test_recreate_pretraining_model(self, config, factory_func):
def test_recreate_pretraining_model(self, config, factory_func):
"""Imported pretraining models can be recreated via a factory function without fairseq.""" """Imported pretraining models can be recreated via a factory function without fairseq."""
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
...@@ -188,6 +217,15 @@ class TestFairseqIntegration(TorchaudioTestCase): ...@@ -188,6 +217,15 @@ class TestFairseqIntegration(TorchaudioTestCase):
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
self.assertEqual(ref_lengths, hyp_lengths) self.assertEqual(ref_lengths, hyp_lengths)
@ALL_PRETRAINING_CONFIGS
def test_wav2vec2_recreate_pretraining_model(self, config, factory_func):
self._test_recreate_pretraining_model(config, factory_func)
@XLSR_PRETRAINING_CONFIGS
@skipIfCudaSmallMemory
def test_xlsr_recreate_pretraining_model(self, config, factory_func):
self._test_recreate_pretraining_model(config, factory_func)
@FINETUNING_CONFIGS @FINETUNING_CONFIGS
def test_import_finetuning_model(self, config, _): def test_import_finetuning_model(self, config, _):
"""Fintuned wav2vec2 models from fairseq can be imported and yields the same results""" """Fintuned wav2vec2 models from fairseq can be imported and yields the same results"""
......
import json import json
import unittest
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio.models.wav2vec2 import wav2vec2_base, wav2vec2_large, wav2vec2_large_lv60k from torchaudio.models.wav2vec2 import (
wav2vec2_base,
wav2vec2_large,
wav2vec2_large_lv60k,
wav2vec2_xlsr_1b,
wav2vec2_xlsr_2b,
wav2vec2_xlsr_300m,
wavlm_base,
wavlm_large,
)
from torchaudio.models.wav2vec2.utils import import_huggingface_model from torchaudio.models.wav2vec2.utils import import_huggingface_model
from torchaudio_unittest.common_utils import get_asset_path, skipIfNoModule, TorchaudioTestCase from torchaudio_unittest.common_utils import (
get_asset_path,
skipIfCudaSmallMemory,
skipIfNoModule,
TorchaudioTestCase,
zip_equal,
)
def _load_config(*paths): def _load_config(*paths):
...@@ -22,6 +38,11 @@ HF_LARGE = _load_config("wav2vec2-large") ...@@ -22,6 +38,11 @@ HF_LARGE = _load_config("wav2vec2-large")
HF_LARGE_LV60 = _load_config("wav2vec2-large-lv60") HF_LARGE_LV60 = _load_config("wav2vec2-large-lv60")
HF_LARGE_XLSR_53 = _load_config("wav2vec2-large-xlsr-53") HF_LARGE_XLSR_53 = _load_config("wav2vec2-large-xlsr-53")
HF_BASE_10K_VOXPOPULI = _load_config("wav2vec2-base-10k-voxpopuli") HF_BASE_10K_VOXPOPULI = _load_config("wav2vec2-base-10k-voxpopuli")
HF_BASE_WAVLM = _load_config("wavlm-base")
HF_LARGE_WAVLM = _load_config("wavlm-large")
HF_XLSR_300M = _load_config("wav2vec2-xls-r-300m")
HF_XLSR_1B = _load_config("wav2vec2-xls-r-1b")
HF_XLSR_2B = _load_config("wav2vec2-xls-r-2b")
# Finetuned # Finetuned
HF_BASE_960H = _load_config("wav2vec2-base-960h") HF_BASE_960H = _load_config("wav2vec2-base-960h")
HF_LARGE_960H = _load_config("wav2vec2-large-960h") HF_LARGE_960H = _load_config("wav2vec2-large-960h")
...@@ -40,6 +61,14 @@ PRETRAIN_CONFIGS = parameterized.expand( ...@@ -40,6 +61,14 @@ PRETRAIN_CONFIGS = parameterized.expand(
], ],
name_func=_name_func, name_func=_name_func,
) )
XLSR_PRETRAIN_CONFIGS = parameterized.expand(
[
(HF_XLSR_300M, wav2vec2_xlsr_300m),
(HF_XLSR_1B, wav2vec2_xlsr_1b),
(HF_XLSR_2B, wav2vec2_xlsr_2b),
],
name_func=_name_func,
)
FINETUNE_CONFIGS = parameterized.expand( FINETUNE_CONFIGS = parameterized.expand(
[ [
(HF_BASE_960H, wav2vec2_base), (HF_BASE_960H, wav2vec2_base),
...@@ -50,8 +79,16 @@ FINETUNE_CONFIGS = parameterized.expand( ...@@ -50,8 +79,16 @@ FINETUNE_CONFIGS = parameterized.expand(
], ],
name_func=_name_func, name_func=_name_func,
) )
WAVLM_CONFIGS = parameterized.expand(
[
(HF_BASE_WAVLM, wavlm_base),
(HF_LARGE_WAVLM, wavlm_large),
],
name_func=_name_func,
)
@unittest.skip("transformers v4.30 seems to break the weight format. See https://github.com/pytorch/audio/issues/3430")
@skipIfNoModule("transformers") @skipIfNoModule("transformers")
class TestHFIntegration(TorchaudioTestCase): class TestHFIntegration(TorchaudioTestCase):
"""Test the process of importing the models from Hugging Face Transformers """Test the process of importing the models from Hugging Face Transformers
...@@ -68,12 +105,14 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -68,12 +105,14 @@ class TestHFIntegration(TorchaudioTestCase):
# However, somehow, once "transformers" is imported, `is_module_available` # However, somehow, once "transformers" is imported, `is_module_available`
# starts to fail. Therefore, we defer importing "transformers" until # starts to fail. Therefore, we defer importing "transformers" until
# the actual tests are started. # the actual tests are started.
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, WavLMConfig, WavLMModel
if config["architectures"] == ["Wav2Vec2Model"]: if config["architectures"] == ["Wav2Vec2Model"]:
return Wav2Vec2Model(Wav2Vec2Config(**config)) return Wav2Vec2Model(Wav2Vec2Config(**config))
if config["architectures"] == ["Wav2Vec2ForCTC"]: if config["architectures"] == ["Wav2Vec2ForCTC"]:
return Wav2Vec2ForCTC(Wav2Vec2Config(**config)) return Wav2Vec2ForCTC(Wav2Vec2Config(**config))
if config["architectures"] == ["WavLMModel"]:
return WavLMModel(WavLMConfig(**config))
raise ValueError(f'Unexpected arch: {config["architectures"]}') raise ValueError(f'Unexpected arch: {config["architectures"]}')
def _test_import_pretrain(self, original, imported, config): def _test_import_pretrain(self, original, imported, config):
...@@ -97,9 +136,8 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -97,9 +136,8 @@ class TestHFIntegration(TorchaudioTestCase):
b, l, e = 16, 3, config["hidden_size"] b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e) x = torch.randn(b, l, e)
mask = torch.randn(b, 1, l, l) mask = torch.randn(b, 1, l, l)
(ref,) = original_(x, attention_mask=mask, output_attentions=False) (ref,) = original_(x, attention_mask=mask, output_attentions=False)
hyp = imported_(x, mask) hyp, _ = imported_(x, mask) # Ignore returned position_bias, which is always None for Wav2Vec2 and HuBERT
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# The whole Encoder Transformer # The whole Encoder Transformer
b, l, e = 16, 3, config["hidden_size"] b, l, e = 16, 3, config["hidden_size"]
...@@ -115,11 +153,6 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -115,11 +153,6 @@ class TestHFIntegration(TorchaudioTestCase):
hyp = imported.aux(x) hyp = imported.aux(x)
self.assertEqual(ref, hyp) self.assertEqual(ref, hyp)
# The whole model without mask # The whole model without mask
x = torch.randn(3, 1024)
ref = original(x).logits
hyp, _ = imported(x)
self.assertEqual(ref, hyp)
# The whole model without mask
batch_size, num_frames = 3, 1024 batch_size, num_frames = 3, 1024
x = torch.randn(batch_size, num_frames) x = torch.randn(batch_size, num_frames)
ref = original(x).logits ref = original(x).logits
...@@ -151,6 +184,14 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -151,6 +184,14 @@ class TestHFIntegration(TorchaudioTestCase):
imported = import_huggingface_model(original).eval() imported = import_huggingface_model(original).eval()
self._test_import_pretrain(original, imported, config) self._test_import_pretrain(original, imported, config)
@XLSR_PRETRAIN_CONFIGS
@skipIfCudaSmallMemory
def test_import_xlsr_pretrain(self, config, _):
"""XLS-R models from HF transformers can be imported and yields the same results"""
original = self._get_model(config).eval()
imported = import_huggingface_model(original).eval()
self._test_import_pretrain(original, imported, config)
@FINETUNE_CONFIGS @FINETUNE_CONFIGS
def test_import_finetune(self, config, _): def test_import_finetune(self, config, _):
"""wav2vec2 models from HF transformers can be imported and yields the same results""" """wav2vec2 models from HF transformers can be imported and yields the same results"""
...@@ -159,6 +200,51 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -159,6 +200,51 @@ class TestHFIntegration(TorchaudioTestCase):
self._test_import_pretrain(original.wav2vec2, imported, config) self._test_import_pretrain(original.wav2vec2, imported, config)
self._test_import_finetune(original, imported, config) self._test_import_finetune(original, imported, config)
@WAVLM_CONFIGS
def test_import_pretrain_wavlm(self, config, _):
"""WavLM models from HF transformers can be imported and yield the same results"""
original = self._get_model(config).eval()
imported = import_huggingface_model(original).eval()
# FeatureExtractor
x = torch.randn(3, 1024)
ref = original.feature_extractor(x).transpose(1, 2)
hyp, _ = imported.feature_extractor(x, None)
self.assertEqual(ref, hyp)
# Feature projection
x = torch.randn(3, 10, config["conv_dim"][-1])
ref = original.feature_projection(x)[0]
hyp = imported.encoder.feature_projection(x)
self.assertEqual(ref, hyp)
# Convolutional Positional Encoder
x = torch.randn(3, 256, config["hidden_size"])
ref = original.encoder.pos_conv_embed(x)
hyp = imported.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp)
position_bias = None
position_bias_imp = None
assert len(original.encoder.layers) > 0
for original_, imported_ in zip_equal(original.encoder.layers, imported.encoder.transformer.layers):
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, l) > 0.5 # HF WavLM model expects the mask to be binary
# The HF WavLM model (original_) takes an "attention mask" but actually uses it as a key padding mask:
# https://github.com/huggingface/transformers/blob/b047472650cba259621549ac27b18fd2066ce18e/src/transformers/models/wavlm/modeling_wavlm.py#L495
ref, position_bias = original_(x, attention_mask=mask, output_attentions=False, position_bias=position_bias)
hyp, position_bias_imp = imported_(x, key_padding_mask=mask.ne(1), position_bias=position_bias_imp)
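# Assumption behind mask.ne(1): torchaudio's key_padding_mask marks positions to
# ignore with True, the inverse of HF's attention_mask convention; for a boolean
# mask, mask.ne(1) is equivalent to ~mask.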
# Masked-out elements are undefined in the output
ref_filled = ref.masked_fill(~mask.unsqueeze(2), 0)
hyp_filled = hyp.masked_fill(~mask.unsqueeze(2), 0)
self.assertEqual(ref_filled, hyp_filled)
# The whole Encoder Transformer
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
ref = original.encoder(x).last_hidden_state
hyp = imported.encoder.transformer(x)
self.assertEqual(ref, hyp)
def _test_recreate(self, imported, reloaded, config): def _test_recreate(self, imported, reloaded, config):
# FeatureExtractor # FeatureExtractor
x = torch.randn(3, 1024) x = torch.randn(3, 1024)
...@@ -221,3 +307,50 @@ class TestHFIntegration(TorchaudioTestCase): ...@@ -221,3 +307,50 @@ class TestHFIntegration(TorchaudioTestCase):
reloaded.load_state_dict(imported.state_dict()) reloaded.load_state_dict(imported.state_dict())
reloaded.eval() reloaded.eval()
self._test_recreate(imported, reloaded, config) self._test_recreate(imported, reloaded, config)
@WAVLM_CONFIGS
def test_recreate_wavlm(self, config, factory_func):
"""Imported models can be recreated via a factory function without Hugging Face transformers."""
imported = import_huggingface_model(self._get_model(config)).eval()
reloaded = factory_func()
reloaded.load_state_dict(imported.state_dict())
reloaded.eval()
# FeatureExtractor
x = torch.randn(3, 1024)
ref, _ = imported.feature_extractor(x, None)
hyp, _ = reloaded.feature_extractor(x, None)
self.assertEqual(ref, hyp)
# Feature projection
x = torch.randn(3, 10, config["conv_dim"][-1])
ref = imported.encoder.feature_projection(x)
hyp = reloaded.encoder.feature_projection(x)
self.assertEqual(ref, hyp)
# Convolutional Positional Encoder
x = torch.randn(3, 256, config["hidden_size"])
ref = imported.encoder.transformer.pos_conv_embed(x)
hyp = reloaded.encoder.transformer.pos_conv_embed(x)
self.assertEqual(ref, hyp)
# Encoder Transformer Layer
position_bias_ref = None
position_bias_hyp = None
for imported_, reloaded_ in zip(imported.encoder.transformer.layers, reloaded.encoder.transformer.layers):
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, l) > 0.5 # Hugging Face WavLM expects the mask to be binary
ref, position_bias_ref = imported_(x, key_padding_mask=mask, position_bias=position_bias_ref)
hyp, position_bias_hyp = reloaded_(x, key_padding_mask=mask, position_bias=position_bias_hyp)
self.assertEqual(ref, hyp)
# The whole Encoder Transformer
# TODO: Add mask pattern. Expected mask shapes and values are different.
b, l, e = 16, 3, config["hidden_size"]
x = torch.randn(b, l, e)
mask = torch.randn(b, 1, l, l)
ref = imported.encoder.transformer(x)
hyp = reloaded.encoder.transformer(x)
self.assertEqual(ref, hyp)
# The whole model
x = torch.randn(3, 1024)
ref, _ = imported(x)
hyp, _ = reloaded(x)
self.assertEqual(ref, hyp)
...@@ -15,6 +15,8 @@ from torchaudio.models.wav2vec2 import ( ...@@ -15,6 +15,8 @@ from torchaudio.models.wav2vec2 import (
wav2vec2_base, wav2vec2_base,
wav2vec2_large, wav2vec2_large,
wav2vec2_large_lv60k, wav2vec2_large_lv60k,
wavlm_base,
wavlm_large,
) )
from torchaudio_unittest.common_utils import skipIfNoCuda, skipIfNoQengine, torch_script, TorchaudioTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, skipIfNoQengine, torch_script, TorchaudioTestCase
...@@ -41,6 +43,14 @@ factory_funcs = parameterized.expand( ...@@ -41,6 +43,14 @@ factory_funcs = parameterized.expand(
name_func=_name_func, name_func=_name_func,
) )
factory_funcs_wavlm = parameterized.expand(
[
(wavlm_base,),
(wavlm_large,),
],
name_func=_name_func,
)
factory_funcs_hubert_pretrain = parameterized.expand( factory_funcs_hubert_pretrain = parameterized.expand(
[ [
(hubert_pretrain_base,), (hubert_pretrain_base,),
...@@ -278,6 +288,131 @@ class TestWav2Vec2Model(TorchaudioTestCase): ...@@ -278,6 +288,131 @@ class TestWav2Vec2Model(TorchaudioTestCase):
self._test_quantize_torchscript(factory_func(aux_num_out=32)) self._test_quantize_torchscript(factory_func(aux_num_out=32))
class TestWavLMModel(TorchaudioTestCase):
def _smoke_test(self, model, device, dtype):
model = model.to(device=device, dtype=dtype)
model = model.eval()
batch_size, num_frames = 3, 1024
waveforms = torch.randn(batch_size, num_frames, device=device, dtype=dtype)
model(waveforms)
@parameterized.expand([(torch.float32,), (torch.float64,)])
def test_cpu_smoke_test(self, dtype):
model = wavlm_base()
self._smoke_test(model, torch.device("cpu"), dtype)
model = wavlm_base(aux_num_out=32)
self._smoke_test(model, torch.device("cpu"), dtype)
@parameterized.expand([(torch.float32,), (torch.float64,)])
@skipIfNoCuda
def test_cuda_smoke_test(self, dtype):
model = wavlm_base()
self._smoke_test(model, torch.device("cuda"), dtype)
model = wavlm_base(aux_num_out=32)
self._smoke_test(model, torch.device("cuda"), dtype)
def _test_batch_consistency(self, model):
model.eval()
batch_size, max_frames = 5, 5 * 1024
waveforms = torch.randn(batch_size, max_frames)
# Batch process
batch_logits, _ = model(waveforms)
# Per-sample process
for i in range(batch_size):
single_logit, _ = model(waveforms[i : i + 1])
batch_logit = batch_logits[i : i + 1]
# Convert to probability so that it's easier to interpret the diff
single_prob = F.softmax(single_logit, dim=2)
batch_prob = F.softmax(batch_logit, dim=2)
# We allow max atol=0.005 -> 0.5%
self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0)
@factory_funcs_wavlm
def test_pretrain_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close"""
self._test_batch_consistency(factory_func())
@factory_funcs_wavlm
def test_finetune_batch_consistency(self, factory_func):
"""Results from single process and batched process should be reasonably close"""
self._test_batch_consistency(factory_func(aux_num_out=32))
def _test_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
waveforms = torch.randn(batch_size, num_frames)
# Compute results with original model
ref_out, ref_len = model(waveforms)
# Compute results with scripted model
scripted = torch_script(model)
hyp_out, hyp_len = scripted(waveforms)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs_wavlm
def test_pretrain_torchscript(self, factory_func):
"""WavLM model should be scriptable"""
self._test_torchscript(factory_func())
@factory_funcs_wavlm
def test_finetune_torchscript(self, factory_func):
"""WavLM model with a head should be scriptable"""
self._test_torchscript(factory_func(aux_num_out=32))
def _test_quantize_smoke_test(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
waveforms = torch.randn(batch_size, num_frames)
_, _ = quantized(waveforms)
@factory_funcs_wavlm
@skipIfNoQengine
def test_quantize(self, factory_func):
"""WavLM should support basic quantization"""
self._test_quantize_smoke_test(factory_func(aux_num_out=32))
def _test_quantize_torchscript(self, model):
model.eval()
batch_size, num_frames = 3, 1024
# Remove the weight normalization forward hook
model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
quantized = tq.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
# A lazy way to check that Modules are different
assert str(quantized) != str(model), "Dynamic quantization did not modify the module."
waveforms = torch.randn(batch_size, num_frames)
ref_out, ref_len = quantized(waveforms)
# Script
scripted = torch_script(quantized)
hyp_out, hyp_len = scripted(waveforms)
self.assertEqual(hyp_out, ref_out)
self.assertEqual(hyp_len, ref_len)
@factory_funcs_wavlm
@skipIfNoQengine
def test_quantize_torchscript(self, factory_func):
"""Quantized WavLM model should be scriptable"""
self._test_quantize_torchscript(factory_func(aux_num_out=32))
def _compute_label_frame(audio_frame: int) -> int: def _compute_label_frame(audio_frame: int) -> int:
"""Compute number of frames in the label tensor based on """Compute number of frames in the label tensor based on
the number of frames in the audio tensor.""" the number of frames in the audio tensor."""
......
import os import os
import platform import platform
import sys import sys
import unittest
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from typing import List, Tuple from typing import List, Tuple
from unittest import skipIf from unittest import skipIf
...@@ -94,8 +95,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -94,8 +95,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
loader = torch.utils.data.DataLoader( loader = torch.utils.data.DataLoader(
dataset, dataset,
batch_size=32, batch_size=32,
num_workers=16, num_workers=4,
worker_init_fn=init_random_seed, worker_init_fn=init_random_seed,
multiprocessing_context=torch.multiprocessing.get_context("spawn"),
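# The "spawn" start method gives each worker a fresh interpreter, which avoids
# fork-related deadlocks with threaded libraries (a likely rationale, stated
# here as an assumption).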
) )
for batch in loader: for batch in loader:
assert batch.shape == (32, 2, 2 * sample_rate) assert batch.shape == (32, 2, 2 * sample_rate)
...@@ -115,8 +117,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): ...@@ -115,8 +117,9 @@ class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
loader = torch.utils.data.DataLoader( loader = torch.utils.data.DataLoader(
dataset, dataset,
batch_size=32, batch_size=32,
num_workers=16, num_workers=4,
worker_init_fn=init_random_seed, worker_init_fn=init_random_seed,
multiprocessing_context=torch.multiprocessing.get_context("spawn"),
) )
for batch in loader: for batch in loader:
assert batch.shape == (32, 2, 2 * sample_rate) assert batch.shape == (32, 2, 2 * sample_rate)
...@@ -131,6 +134,7 @@ def speed(path): ...@@ -131,6 +134,7 @@ def speed(path):
return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0] return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0]
@unittest.skipIf(True, "Skipped unconditionally")
@skipIfNoSox @skipIfNoSox
class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase): class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase):
backend = "sox_io" backend = "sox_io"
......
...@@ -54,24 +54,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -54,24 +54,3 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
_found, _sr = sox_effects.apply_effects_file( _found, _sr = sox_effects.apply_effects_file(
input_path, effects, normalize=False, channels_first=channels_first input_path, effects, normalize=False, channels_first=channels_first
) )
@parameterized.expand(
load_params("sox_effect_test_args.jsonl"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects_fileobj(self, args):
"""`apply_effects_file` should return identical data as sox command"""
dtype = "int32"
channels_first = True
effects = args["effects"]
num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000)
input_path = self.get_temp_path("input.wav")
data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, input_sr, channels_first=channels_first)
with open(input_path, "rb") as fileobj:
_found, _sr = sox_effects.apply_effects_file(
fileobj, effects, normalize=False, channels_first=channels_first
)