Commit 5859923a authored by Joao Gomes, committed by Facebook GitHub Bot
Browse files

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
from argparse import ArgumentParser
from pathlib import Path
from lightning_train import _get_model, _get_dataloader, sisdri_metric
import mir_eval
import torch
from lightning_train import _get_model, _get_dataloader, sisdri_metric
def _eval(model, data_loader, device):
......@@ -19,12 +19,9 @@ def _eval(model, data_loader, device):
mix = mix.repeat(1, src.shape[1], 1).cpu().detach().numpy()
sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(src[0], est[0])
sdr_mix, sir_mix, sar_mix, _ = mir_eval.separation.bss_eval_sources(src[0], mix[0])
results += torch.tensor([
sdr.mean() - sdr_mix.mean(),
sisdri,
sir.mean() - sir_mix.mean(),
sar.mean() - sar_mix.mean()
])
results += torch.tensor(
[sdr.mean() - sdr_mix.mean(), sisdri, sir.mean() - sir_mix.mean(), sar.mean() - sar_mix.mean()]
)
results /= len(data_loader)
print("SDR improvement: ", results[0].item())
print("Si-SDR improvement: ", results[1].item())
......@@ -63,28 +60,20 @@ def cli_main():
help="Sample rate of audio files in the given dataset. (default: 8000)",
)
parser.add_argument(
"--exp-dir",
default=Path("./exp"),
type=Path,
help="The directory to save checkpoints and logs."
)
parser.add_argument(
"--gpu-device",
default=-1,
type=int,
help="The gpu device for model inference. (default: -1)"
"--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
)
parser.add_argument("--gpu-device", default=-1, type=int, help="The gpu device for model inference. (default: -1)")
args = parser.parse_args()
model = _get_model(num_sources=2)
state_dict = torch.load(args.exp_dir / 'best_model.pth')
state_dict = torch.load(args.exp_dir / "best_model.pth")
model.load_state_dict(state_dict)
if args.gpu_device != -1:
device = torch.device('cuda:' + str(args.gpu_device))
device = torch.device("cuda:" + str(args.gpu_device))
else:
device = torch.device('cpu')
device = torch.device("cpu")
model = model.to(device)
......
#!/usr/bin/env python3
# pyre-strict
from pathlib import Path
from argparse import ArgumentParser
from pathlib import Path
from typing import (
Any,
Callable,
......@@ -34,10 +34,7 @@ class Batch(TypedDict):
def sisdri_metric(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: torch.Tensor
estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Compute the improvement of scale-invariant SDR. (SI-SDRi).
......@@ -100,11 +97,7 @@ def sdri_metric(
return sdri.mean().item()
def si_sdr_loss(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: torch.Tensor
) -> torch.Tensor:
def si_sdr_loss(estimate: torch.Tensor, reference: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Compute the Si-SDR loss.
Args:
......@@ -181,22 +174,16 @@ class ConvTasNetModule(LightningModule):
"""
return self.model(x)
def training_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
def training_step(self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return self._step(batch, batch_idx, "train")
def validation_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
def validation_step(self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""
Operates on a single batch of data from the validation set.
"""
return self._step(batch, batch_idx, "val")
def test_step(
self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Optional[Dict[str, Any]]:
def test_step(self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any) -> Optional[Dict[str, Any]]:
"""
Operates on a single batch of data from the test set.
"""
......@@ -222,11 +209,7 @@ class ConvTasNetModule(LightningModule):
lr_scheduler = self.lr_scheduler
if not lr_scheduler:
return self.optim
epoch_schedulers = {
'scheduler': lr_scheduler,
'monitor': 'Losses/val_loss',
'interval': 'epoch'
}
epoch_schedulers = {"scheduler": lr_scheduler, "monitor": "Losses/val_loss", "interval": "epoch"}
return [self.optim], [epoch_schedulers]
def _compute_metrics(
......@@ -305,11 +288,9 @@ def _get_dataloader(
train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
dataset_type, root_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
)
train_collate_fn = dataset_utils.get_collate_fn(
dataset_type, mode='train', sample_rate=sample_rate, duration=3
)
train_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="train", sample_rate=sample_rate, duration=3)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test', sample_rate=sample_rate)
test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="test", sample_rate=sample_rate)
train_loader = DataLoader(
train_dataset,
......@@ -367,10 +348,7 @@ def cli_main():
help="Sample rate of audio files in the given dataset. (default: 8000)",
)
parser.add_argument(
"--exp-dir",
default=Path("./exp"),
type=Path,
help="The directory to save checkpoints and logs."
"--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
)
parser.add_argument(
"--epochs",
......@@ -409,9 +387,7 @@ def cli_main():
model = _get_model(num_sources=args.num_speakers)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
train_loader, valid_loader, eval_loader = _get_dataloader(
args.dataset,
args.root_dir,
......@@ -438,12 +414,7 @@ def cli_main():
)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=True,
verbose=True
checkpoint_dir, monitor="Losses/val_loss", mode="min", save_top_k=5, save_weights_only=True, verbose=True
)
callbacks = [
checkpoint,
......
......@@ -15,13 +15,12 @@ number of training subprocesses (as operation mode 2). You can reduce the number
When launching the script as a worker process of a distributed training, you need to configure
the coordination of the workers.
"""
import sys
import logging
import argparse
import logging
import subprocess
import sys
import torch
from utils import dist_utils
_LG = dist_utils.getLogger(__name__)
......@@ -88,19 +87,13 @@ def _parse_args(args=None):
type=int,
help="Set random seed value. (default: None)",
)
parser.add_argument(
"rest", nargs=argparse.REMAINDER, help="Model-specific arguments."
)
parser.add_argument("rest", nargs=argparse.REMAINDER, help="Model-specific arguments.")
namespace = parser.parse_args(args)
if namespace.worker_id is None:
if namespace.device_id is not None:
raise ValueError(
"`--device-id` cannot be provided when runing as master process."
)
raise ValueError("`--device-id` cannot be provided when runing as master process.")
if namespace.num_workers > max_world_size:
raise ValueError(
"--num-workers ({num_workers}) cannot exceed {device_count}."
)
raise ValueError("--num-workers ({num_workers}) cannot exceed {device_count}.")
if namespace.rest[:1] == ["--"]:
namespace.rest = namespace.rest[1:]
return namespace
......@@ -120,7 +113,7 @@ def _main(cli_args):
world_size=args.num_workers,
rank=args.worker_id,
local_rank=args.device_id,
backend='nccl' if torch.cuda.is_available() else 'gloo',
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method=args.sync_protocol,
)
if args.random_seed is not None:
......@@ -137,12 +130,7 @@ def _run_training_subprocesses(num_workers, original_args):
for i in range(num_workers):
worker_arg = ["--worker-id", f"{i}", "--num-workers", f"{num_workers}"]
device_arg = ["--device-id", f"{i}"] if torch.cuda.is_available() else []
command = (
[sys.executable, "-u", sys.argv[0]]
+ worker_arg
+ device_arg
+ original_args
)
command = [sys.executable, "-u", sys.argv[0]] + worker_arg + device_arg + original_args
_LG.info("Launching worker %s: `%s`", i, " ".join(command))
worker = subprocess.Popen(command)
workers.append(worker)
......@@ -163,9 +151,7 @@ def _run_training(args):
def _init_logger(rank=None, debug=False):
worker_fmt = "[master]" if rank is None else f"[worker {rank:2d}]"
message_fmt = (
"%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
)
message_fmt = "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
logging.basicConfig(
level=logging.DEBUG if debug else logging.INFO,
format=f"%(asctime)s: {worker_fmt} {message_fmt}",
......
......@@ -4,4 +4,4 @@ from . import (
metrics,
)
__all__ = ['dataset', 'dist_utils', 'metrics']
__all__ = ["dataset", "dist_utils", "metrics"]
from . import utils, wsj0mix
__all__ = ['utils', 'wsj0mix']
__all__ = ["utils", "wsj0mix"]
from typing import List
from functools import partial
from collections import namedtuple
from functools import partial
from typing import List
from torchaudio.datasets import LibriMix
import torch
from torchaudio.datasets import LibriMix
from . import wsj0mix
......@@ -30,8 +30,8 @@ def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_r
src = torch.cat(sample[2], 0) # [num_sources, time]
num_channels, num_frames = src.shape
num_seconds = torch.div(num_frames, sample_rate, rounding_mode='floor')
target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode='floor')
num_seconds = torch.div(num_frames, sample_rate, rounding_mode="floor")
target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode="floor")
if num_frames >= target_num_frames:
if random_start and num_frames > target_num_frames:
start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
......@@ -81,7 +81,7 @@ def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
assert mode in ["train", "test"]
if dataset_type in ["wsj0mix", "librimix"]:
if mode == 'train':
if mode == "train":
if sample_rate is None:
raise ValueError("sample_rate is not given.")
return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
......
......@@ -2,9 +2,8 @@ from pathlib import Path
from typing import Union, Tuple, List
import torch
from torch.utils.data import Dataset
import torchaudio
from torch.utils.data import Dataset
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
......@@ -21,6 +20,7 @@ class WSJ0Mix(Dataset):
different sample rate, raises ``ValueError``.
audio_ext (str, optional): The extension of audio files to find. (default: ".wav")
"""
def __init__(
self,
root: Union[str, Path],
......@@ -51,9 +51,7 @@ class WSJ0Mix(Dataset):
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
srcs.append(src)
return self.sample_rate, mixed, srcs
......
import os
import csv
import types
import logging
import os
import types
import torch
import torch.distributed as dist
......@@ -22,9 +22,7 @@ def getLogger(name):
_LG = getLogger(__name__)
def setup_distributed(
world_size, rank, local_rank, backend="nccl", init_method="env://"
):
def setup_distributed(world_size, rank, local_rank, backend="nccl", init_method="env://"):
"""Perform env setup and initialization for distributed training"""
if init_method == "env://":
_set_env_vars(world_size, rank, local_rank)
......
import math
from typing import Optional
from itertools import permutations
from typing import Optional
import torch
def sdr(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8
estimate: torch.Tensor, reference: torch.Tensor, mask: Optional[torch.Tensor] = None, epsilon: float = 1e-8
) -> torch.Tensor:
"""Computes source-to-distortion ratio.
......@@ -86,11 +83,11 @@ class PIT(torch.nn.Module):
self.utility_func = utility_func
def forward(
self,
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8
self,
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8,
) -> torch.Tensor:
"""Compute utterance-level PIT Loss
......@@ -112,9 +109,7 @@ class PIT(torch.nn.Module):
batch_size, num_speakers = reference.shape[:2]
num_permute = math.factorial(num_speakers)
util_mat = torch.zeros(
batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
)
util_mat = torch.zeros(batch_size, num_permute, dtype=estimate.dtype, device=estimate.device)
for i, idx in enumerate(permutations(range(num_speakers))):
util = self.utility_func(estimate, reference[:, idx, :], mask=mask, epsilon=epsilon)
util_mat[:, i] = util.mean(dim=1) # take the average over speaker dimension
......@@ -125,10 +120,8 @@ _sdr_pit = PIT(utility_func=sdr)
def sdr_pit(
estimate: torch.Tensor,
reference: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8):
estimate: torch.Tensor, reference: torch.Tensor, mask: Optional[torch.Tensor] = None, epsilon: float = 1e-8
):
"""Computes scale-invariant source-to-distortion ratio.
1. adjust both estimate and reference to have 0-mean
......@@ -164,11 +157,11 @@ def sdr_pit(
def sdri(
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8,
estimate: torch.Tensor,
reference: torch.Tensor,
mix: torch.Tensor,
mask: Optional[torch.Tensor] = None,
epsilon: float = 1e-8,
) -> torch.Tensor:
"""Compute the improvement of SDR (SDRi).
......
......@@ -93,9 +93,7 @@ class ASRTest(unittest.TestCase):
def test_transcribe_file(self):
task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
_, transcription = transcribe_file(
self.args, task, generator, models, sp, tgt_dict
)
_, transcription = transcribe_file(self.args, task, generator, models, sp, tgt_dict)
expected_transcription = [["THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG"]]
self.assertEqual(transcription, expected_transcription, msg=str(transcription))
......
......@@ -32,9 +32,9 @@ print(torchaudio.__version__)
import math
import os
import requests
import matplotlib.pyplot as plt
import requests
from IPython.display import Audio, display
......@@ -164,7 +164,7 @@ def get_rir_sample(*, resample=None, processed=False):
rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
if not processed:
return rir_raw, sample_rate
rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir_raw[:, int(sample_rate * 1.01) : int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
return rir, sample_rate
......@@ -225,9 +225,7 @@ effects = [
]
# Apply effects
waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
waveform1, sample_rate1, effects
)
waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(waveform1, sample_rate1, effects)
print_stats(waveform1, sample_rate=sample_rate1, src="Original")
print_stats(waveform2, sample_rate=sample_rate2, src="Effects Applied")
......@@ -291,7 +289,7 @@ play_audio(rir_raw, sample_rate)
# the signal power, then flip along the time axis.
#
rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
rir = rir_raw[:, int(sample_rate * 1.01) : int(sample_rate * 1.3)]
rir = rir / torch.norm(rir, p=2)
rir = torch.flip(rir, [1])
......
......@@ -33,10 +33,10 @@ print(torchaudio.__version__)
# -------------------------------------------------------------------------------
import os
import requests
import librosa
import matplotlib.pyplot as plt
import requests
_SAMPLE_DIR = "_assets"
......@@ -125,17 +125,13 @@ stretch = T.TimeStretch()
rate = 1.2
spec_ = stretch(spec, rate)
plot_spectrogram(
torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(
torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
######################################################################
# TimeMasking
......
......@@ -51,10 +51,10 @@ print(torchaudio.__version__)
# -------------------------------------------------------------------------------
import os
import requests
import librosa
import matplotlib.pyplot as plt
import requests
from IPython.display import Audio, display
......@@ -199,9 +199,7 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, nfcc.shape[1])
ln2 = axis2.plot(
time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--"
)
ln2 = axis2.plot(time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--")
lns = ln1 + ln2
labels = [l.get_label() for l in lns]
......
......@@ -32,13 +32,13 @@ print(torchaudio.__version__)
import io
import os
import requests
import tarfile
import boto3
import matplotlib.pyplot as plt
import requests
from botocore import UNSIGNED
from botocore.config import Config
import matplotlib.pyplot as plt
from IPython.display import Audio, display
......@@ -348,14 +348,12 @@ frame_offset, num_frames = 16000, 16000 # Fetch and decode the 1 - 2 seconds
print("Fetching all the data...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform1, sample_rate1 = torchaudio.load(response.raw)
waveform1 = waveform1[:, frame_offset: frame_offset + num_frames]
waveform1 = waveform1[:, frame_offset : frame_offset + num_frames]
print(f" - Fetched {response.raw.tell()} bytes")
print("Fetching until the requested frames are available...")
with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
waveform2, sample_rate2 = torchaudio.load(
response.raw, frame_offset=frame_offset, num_frames=num_frames
)
waveform2, sample_rate2 = torchaudio.load(response.raw, frame_offset=frame_offset, num_frames=num_frames)
print(f" - Fetched {response.raw.tell()} bytes")
print("Checking the resulting waveform ... ", end="")
......
......@@ -38,8 +38,8 @@ import time
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import pandas as pd
from IPython.display import Audio, display
DEFAULT_OFFSET = 201
......@@ -56,9 +56,7 @@ def _get_log_freq(sample_rate, max_sweep_rate, offset):
"""
start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
return (
torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
)
return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
def _get_inverse_log_freq(freq, sample_rate, offset):
......@@ -192,9 +190,7 @@ def benchmark_resample(
waveform_np = waveform.squeeze().numpy()
begin = time.time()
for _ in range(iters):
librosa.resample(
waveform_np, sample_rate, resample_rate, res_type=librosa_type
)
librosa.resample(waveform_np, sample_rate, resample_rate, res_type=librosa_type)
elapsed = time.time() - begin
return elapsed / iters
......@@ -264,14 +260,10 @@ play_audio(waveform, sample_rate)
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=6
)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, lowpass_filter_width=128
)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
......@@ -315,14 +307,10 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
sample_rate = 48000
resample_rate = 32000
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation"
)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
resampled_waveform = F.resample(
waveform, sample_rate, resample_rate, resampling_method="kaiser_window"
)
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
......@@ -351,13 +339,9 @@ resampled_waveform = F.resample(
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(
waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best"
)
librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best")
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)"
)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse)
......@@ -372,18 +356,12 @@ resampled_waveform = F.resample(
resampling_method="kaiser_window",
beta=8.555504641634386,
)
plot_specgram(
resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)"
)
plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
librosa_resampled_waveform = torch.from_numpy(
librosa.resample(
waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast"
)
librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast")
).unsqueeze(0)
plot_sweep(
librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)"
)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse)
......@@ -426,29 +404,19 @@ for label in configs:
waveform = get_sine_sweep(sample_rate)
# sinc 64 zero-crossings
f_time = benchmark_resample(
"functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64
)
t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64
)
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append("sinc (width 64)")
# sinc 6 zero-crossings
f_time = benchmark_resample(
"functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16
)
t_time = benchmark_resample(
"transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16
)
f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
times.append([None, 1000 * f_time, 1000 * t_time])
rows.append("sinc (width 16)")
# kaiser best
lib_time = benchmark_resample(
"librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best"
)
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best")
f_time = benchmark_resample(
"functional",
waveform,
......@@ -473,9 +441,7 @@ for label in configs:
rows.append("kaiser_best")
# kaiser fast
lib_time = benchmark_resample(
"librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast"
)
lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast")
f_time = benchmark_resample(
"functional",
waveform,
......@@ -499,8 +465,6 @@ for label in configs:
times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
rows.append("kaiser_fast")
df = pd.DataFrame(
times, columns=["librosa", "functional", "transforms"], index=rows
)
df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
display(df.round(2))
......@@ -40,12 +40,12 @@ Recognition <https://arxiv.org/abs/2007.09127>`__.
import os
from dataclasses import dataclass
import torch
import torchaudio
import requests
import IPython
import matplotlib
import matplotlib.pyplot as plt
import IPython
import requests
import torch
import torchaudio
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
......@@ -325,7 +325,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.set_title("Path, label and probability for each label")
......@@ -383,12 +383,8 @@ def merge_words(segments, separator="|"):
if i1 != i2:
segs = segments[i1:i2]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(
seg.length for seg in segs
)
words.append(
Segment(word, segments[i1].start, segments[i2 - 1].end, score)
)
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
i1 = i2 + 1
i2 = i1
else:
......@@ -408,7 +404,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
......@@ -464,9 +460,7 @@ def display_segment(i):
x1 = int(ratio * word.end)
filename = f"_assets/{i}_{word.label}.wav"
torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
print(
f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec"
)
print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
return IPython.display.Audio(filename)
......
......@@ -44,10 +44,11 @@ MVDR with torchaudio
#
import os
import IPython.display as ipd
import requests
import torch
import torchaudio
import IPython.display as ipd
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......@@ -167,9 +168,7 @@ for solution in ["ref_channel", "stv_evd", "stv_power"]:
results_single = {}
for solution in ["ref_channel", "stv_evd", "stv_power"]:
mvdr = torchaudio.transforms.MVDR(
ref_channel=0, solution=solution, multi_mask=False
)
mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)
stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
est = istft(stft_est, length=mix.shape[-1])
results_single[solution] = est
......@@ -211,9 +210,7 @@ def si_sdr(estimate, reference, epsilon=1e-8):
#
for solution in results_single:
print(
solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1])
)
print(solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1]))
######################################################################
# Multi-channel mask results
......@@ -221,9 +218,7 @@ for solution in results_single:
#
for solution in results_multi:
print(
solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1])
)
print(solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1]))
######################################################################
# Original audio
......
......@@ -41,12 +41,12 @@ pre-trained models from wav2vec 2.0
import os
import torch
import torchaudio
import requests
import IPython
import matplotlib
import matplotlib.pyplot as plt
import IPython
import requests
import torch
import torchaudio
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
......
......@@ -7,6 +7,10 @@ Text-to-Speech with Tacotron2
"""
import IPython
import matplotlib
import matplotlib.pyplot as plt
######################################################################
# Overview
# --------
......@@ -58,10 +62,6 @@ Text-to-Speech with Tacotron2
import torch
import torchaudio
import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
......@@ -271,9 +271,7 @@ fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
ax1.imshow(spec[0].cpu().detach())
ax2.plot(waveforms[0].cpu().detach())
torchaudio.save(
"_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate
)
torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
IPython.display.Audio("_assets/output_wavernn.wav")
......@@ -332,9 +330,7 @@ checkpoint = torch.hub.load_state_dict_from_url(
progress=False,
map_location=device,
)
state_dict = {
key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()
}
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment