Commit 5859923a, authored by Joao Gomes, committed by Facebook GitHub Bot

Apply arc lint to pytorch audio (#2096)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/2096

run: `arc lint --apply-patches --paths-cmd 'hg files -I "./**/*.py"'`

Reviewed By: mthrok

Differential Revision: D33297351

fbshipit-source-id: 7bf5956edf0717c5ca90219f72414ff4eeaf5aa8
parent 0e5913d5
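
For readers outside Meta's `arc` toolchain, a roughly equivalent formatting pass can be sketched with open-source tools. This is an assumption inferred from the shape of the diff below (usort-style import sorting plus Black-style reflowing at a ~120-column limit), not the commit's actual lint configuration:

    pip install black usort
    usort format .
    black --line-length 120 .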
 from argparse import ArgumentParser
 from pathlib import Path
-from lightning_train import _get_model, _get_dataloader, sisdri_metric
 import mir_eval
 import torch
+from lightning_train import _get_model, _get_dataloader, sisdri_metric

 def _eval(model, data_loader, device):
@@ -19,12 +19,9 @@ def _eval(model, data_loader, device):
         mix = mix.repeat(1, src.shape[1], 1).cpu().detach().numpy()
         sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(src[0], est[0])
         sdr_mix, sir_mix, sar_mix, _ = mir_eval.separation.bss_eval_sources(src[0], mix[0])
-        results += torch.tensor([
-            sdr.mean() - sdr_mix.mean(),
-            sisdri,
-            sir.mean() - sir_mix.mean(),
-            sar.mean() - sar_mix.mean()
-        ])
+        results += torch.tensor(
+            [sdr.mean() - sdr_mix.mean(), sisdri, sir.mean() - sir_mix.mean(), sar.mean() - sar_mix.mean()]
+        )
     results /= len(data_loader)
     print("SDR improvement: ", results[0].item())
     print("Si-SDR improvement: ", results[1].item())
@@ -63,28 +60,20 @@ def cli_main():
         help="Sample rate of audio files in the given dataset. (default: 8000)",
     )
     parser.add_argument(
-        "--exp-dir",
-        default=Path("./exp"),
-        type=Path,
-        help="The directory to save checkpoints and logs."
+        "--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
     )
-    parser.add_argument(
-        "--gpu-device",
-        default=-1,
-        type=int,
-        help="The gpu device for model inference. (default: -1)"
-    )
+    parser.add_argument("--gpu-device", default=-1, type=int, help="The gpu device for model inference. (default: -1)")
     args = parser.parse_args()
     model = _get_model(num_sources=2)
-    state_dict = torch.load(args.exp_dir / 'best_model.pth')
+    state_dict = torch.load(args.exp_dir / "best_model.pth")
     model.load_state_dict(state_dict)
     if args.gpu_device != -1:
-        device = torch.device('cuda:' + str(args.gpu_device))
+        device = torch.device("cuda:" + str(args.gpu_device))
     else:
-        device = torch.device('cpu')
+        device = torch.device("cpu")
     model = model.to(device)
...
 #!/usr/bin/env python3
 # pyre-strict
-from pathlib import Path
 from argparse import ArgumentParser
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -34,10 +34,7 @@ class Batch(TypedDict):
 def sisdri_metric(
-    estimate: torch.Tensor,
-    reference: torch.Tensor,
-    mix: torch.Tensor,
-    mask: torch.Tensor
+    estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor, mask: torch.Tensor
 ) -> torch.Tensor:
     """Compute the improvement of scale-invariant SDR. (SI-SDRi).
@@ -100,11 +97,7 @@ def sdri_metric(
     return sdri.mean().item()

-def si_sdr_loss(
-    estimate: torch.Tensor,
-    reference: torch.Tensor,
-    mask: torch.Tensor
-) -> torch.Tensor:
+def si_sdr_loss(estimate: torch.Tensor, reference: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
     """Compute the Si-SDR loss.

     Args:
@@ -181,22 +174,16 @@ class ConvTasNetModule(LightningModule):
         """
         return self.model(x)

-    def training_step(
-        self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
-    ) -> Dict[str, Any]:
+    def training_step(self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         return self._step(batch, batch_idx, "train")

-    def validation_step(
-        self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
-    ) -> Dict[str, Any]:
+    def validation_step(self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """
         Operates on a single batch of data from the validation set.
         """
         return self._step(batch, batch_idx, "val")

-    def test_step(
-        self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
+    def test_step(self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any) -> Optional[Dict[str, Any]]:
         """
         Operates on a single batch of data from the test set.
         """
@@ -222,11 +209,7 @@ class ConvTasNetModule(LightningModule):
         lr_scheduler = self.lr_scheduler
         if not lr_scheduler:
             return self.optim
-        epoch_schedulers = {
-            'scheduler': lr_scheduler,
-            'monitor': 'Losses/val_loss',
-            'interval': 'epoch'
-        }
+        epoch_schedulers = {"scheduler": lr_scheduler, "monitor": "Losses/val_loss", "interval": "epoch"}
         return [self.optim], [epoch_schedulers]

     def _compute_metrics(
@@ -305,11 +288,9 @@ def _get_dataloader(
     train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset(
         dataset_type, root_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split
     )
-    train_collate_fn = dataset_utils.get_collate_fn(
-        dataset_type, mode='train', sample_rate=sample_rate, duration=3
-    )
-    test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test', sample_rate=sample_rate)
+    train_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="train", sample_rate=sample_rate, duration=3)
+    test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode="test", sample_rate=sample_rate)
     train_loader = DataLoader(
         train_dataset,
@@ -367,10 +348,7 @@ def cli_main():
         help="Sample rate of audio files in the given dataset. (default: 8000)",
     )
     parser.add_argument(
-        "--exp-dir",
-        default=Path("./exp"),
-        type=Path,
-        help="The directory to save checkpoints and logs."
+        "--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
     )
     parser.add_argument(
         "--epochs",
@@ -409,9 +387,7 @@ def cli_main():
     model = _get_model(num_sources=args.num_speakers)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, mode="min", factor=0.5, patience=5
-    )
+    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
     train_loader, valid_loader, eval_loader = _get_dataloader(
         args.dataset,
         args.root_dir,
@@ -438,12 +414,7 @@ def cli_main():
     )
     checkpoint_dir = args.exp_dir / "checkpoints"
     checkpoint = ModelCheckpoint(
-        checkpoint_dir,
-        monitor="Losses/val_loss",
-        mode="min",
-        save_top_k=5,
-        save_weights_only=True,
-        verbose=True
+        checkpoint_dir, monitor="Losses/val_loss", mode="min", save_top_k=5, save_weights_only=True, verbose=True
     )
     callbacks = [
         checkpoint,
...
@@ -15,13 +15,12 @@ number of training subprocesses (as operaiton mode 2). You can reduce the number
 When launching the script as a worker process of a distributed training, you need to configure
 the coordination of the workers.
 """
-import sys
-import logging
-
 import argparse
+import logging
 import subprocess
+import sys

 import torch

 from utils import dist_utils

 _LG = dist_utils.getLogger(__name__)
@@ -88,19 +87,13 @@ def _parse_args(args=None):
         type=int,
         help="Set random seed value. (default: None)",
     )
-    parser.add_argument(
-        "rest", nargs=argparse.REMAINDER, help="Model-specific arguments."
-    )
+    parser.add_argument("rest", nargs=argparse.REMAINDER, help="Model-specific arguments.")
     namespace = parser.parse_args(args)
     if namespace.worker_id is None:
         if namespace.device_id is not None:
-            raise ValueError(
-                "`--device-id` cannot be provided when runing as master process."
-            )
+            raise ValueError("`--device-id` cannot be provided when runing as master process.")
         if namespace.num_workers > max_world_size:
-            raise ValueError(
-                "--num-workers ({num_workers}) cannot exceed {device_count}."
-            )
+            raise ValueError("--num-workers ({num_workers}) cannot exceed {device_count}.")
     if namespace.rest[:1] == ["--"]:
         namespace.rest = namespace.rest[1:]
     return namespace
@@ -120,7 +113,7 @@ def _main(cli_args):
         world_size=args.num_workers,
         rank=args.worker_id,
         local_rank=args.device_id,
-        backend='nccl' if torch.cuda.is_available() else 'gloo',
+        backend="nccl" if torch.cuda.is_available() else "gloo",
         init_method=args.sync_protocol,
     )
     if args.random_seed is not None:
@@ -137,12 +130,7 @@ def _run_training_subprocesses(num_workers, original_args):
     for i in range(num_workers):
         worker_arg = ["--worker-id", f"{i}", "--num-workers", f"{num_workers}"]
         device_arg = ["--device-id", f"{i}"] if torch.cuda.is_available() else []
-        command = (
-            [sys.executable, "-u", sys.argv[0]]
-            + worker_arg
-            + device_arg
-            + original_args
-        )
+        command = [sys.executable, "-u", sys.argv[0]] + worker_arg + device_arg + original_args
         _LG.info("Launching worker %s: `%s`", i, " ".join(command))
         worker = subprocess.Popen(command)
         workers.append(worker)
@@ -163,9 +151,7 @@ def _run_training(args):
 def _init_logger(rank=None, debug=False):
     worker_fmt = "[master]" if rank is None else f"[worker {rank:2d}]"
-    message_fmt = (
-        "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
-    )
+    message_fmt = "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s"
     logging.basicConfig(
         level=logging.DEBUG if debug else logging.INFO,
         format=f"%(asctime)s: {worker_fmt} {message_fmt}",
...
@@ -4,4 +4,4 @@ from . import (
     metrics,
 )

-__all__ = ['dataset', 'dist_utils', 'metrics']
+__all__ = ["dataset", "dist_utils", "metrics"]
...
 from . import utils, wsj0mix

-__all__ = ['utils', 'wsj0mix']
+__all__ = ["utils", "wsj0mix"]
...
-from typing import List
-from functools import partial
 from collections import namedtuple
+from functools import partial
+from typing import List

-from torchaudio.datasets import LibriMix
 import torch
+from torchaudio.datasets import LibriMix

 from . import wsj0mix
@@ -30,8 +30,8 @@ def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_r
     src = torch.cat(sample[2], 0)  # [num_sources, time]

     num_channels, num_frames = src.shape
-    num_seconds = torch.div(num_frames, sample_rate, rounding_mode='floor')
-    target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode='floor')
+    num_seconds = torch.div(num_frames, sample_rate, rounding_mode="floor")
+    target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode="floor")
     if num_frames >= target_num_frames:
         if random_start and num_frames > target_num_frames:
             start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
@@ -81,7 +81,7 @@ def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
 def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
     assert mode in ["train", "test"]
     if dataset_type in ["wsj0mix", "librimix"]:
-        if mode == 'train':
+        if mode == "train":
             if sample_rate is None:
                 raise ValueError("sample_rate is not given.")
             return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
...
@@ -2,9 +2,8 @@ from pathlib import Path
 from typing import Union, Tuple, List

 import torch
-from torch.utils.data import Dataset
 import torchaudio
+from torch.utils.data import Dataset

 SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
@@ -21,6 +20,7 @@ class WSJ0Mix(Dataset):
             different sample rate, raises ``ValueError``.
         audio_ext (str, optional): The extension of audio files to find. (default: ".wav")
     """
+
     def __init__(
         self,
         root: Union[str, Path],
@@ -51,9 +51,7 @@ class WSJ0Mix(Dataset):
         for i, dir_ in enumerate(self.src_dirs):
             src = self._load_audio(str(dir_ / filename))
             if mixed.shape != src.shape:
-                raise ValueError(
-                    f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
-                )
+                raise ValueError(f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}")
             srcs.append(src)
         return self.sample_rate, mixed, srcs
...
-import os
 import csv
-import types
 import logging
+import os
+import types

 import torch
 import torch.distributed as dist
@@ -22,9 +22,7 @@ def getLogger(name):
 _LG = getLogger(__name__)

-def setup_distributed(
-    world_size, rank, local_rank, backend="nccl", init_method="env://"
-):
+def setup_distributed(world_size, rank, local_rank, backend="nccl", init_method="env://"):
     """Perform env setup and initialization for distributed training"""
     if init_method == "env://":
         _set_env_vars(world_size, rank, local_rank)
...
 import math
-from typing import Optional
 from itertools import permutations
+from typing import Optional

 import torch

 def sdr(
-    estimate: torch.Tensor,
-    reference: torch.Tensor,
-    mask: Optional[torch.Tensor] = None,
-    epsilon: float = 1e-8
+    estimate: torch.Tensor, reference: torch.Tensor, mask: Optional[torch.Tensor] = None, epsilon: float = 1e-8
 ) -> torch.Tensor:
     """Computes source-to-distortion ratio.
@@ -86,11 +83,11 @@ class PIT(torch.nn.Module):
         self.utility_func = utility_func

     def forward(
         self,
         estimate: torch.Tensor,
         reference: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        epsilon: float = 1e-8
+        epsilon: float = 1e-8,
     ) -> torch.Tensor:
         """Compute utterance-level PIT Loss
@@ -112,9 +109,7 @@ class PIT(torch.nn.Module):
         batch_size, num_speakers = reference.shape[:2]
         num_permute = math.factorial(num_speakers)

-        util_mat = torch.zeros(
-            batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
-        )
+        util_mat = torch.zeros(batch_size, num_permute, dtype=estimate.dtype, device=estimate.device)
         for i, idx in enumerate(permutations(range(num_speakers))):
             util = self.utility_func(estimate, reference[:, idx, :], mask=mask, epsilon=epsilon)
             util_mat[:, i] = util.mean(dim=1)  # take the average over speaker dimension
@@ -125,10 +120,8 @@ _sdr_pit = PIT(utility_func=sdr)

 def sdr_pit(
-    estimate: torch.Tensor,
-    reference: torch.Tensor,
-    mask: Optional[torch.Tensor] = None,
-    epsilon: float = 1e-8):
+    estimate: torch.Tensor, reference: torch.Tensor, mask: Optional[torch.Tensor] = None, epsilon: float = 1e-8
+):
     """Computes scale-invariant source-to-distortion ratio.

     1. adjust both estimate and reference to have 0-mean
@@ -164,11 +157,11 @@ def sdr_pit(
 def sdri(
     estimate: torch.Tensor,
     reference: torch.Tensor,
     mix: torch.Tensor,
     mask: Optional[torch.Tensor] = None,
     epsilon: float = 1e-8,
 ) -> torch.Tensor:
     """Compute the improvement of SDR (SDRi).
...
@@ -93,9 +93,7 @@ class ASRTest(unittest.TestCase):
     def test_transcribe_file(self):
         task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger)
-        _, transcription = transcribe_file(
-            self.args, task, generator, models, sp, tgt_dict
-        )
+        _, transcription = transcribe_file(self.args, task, generator, models, sp, tgt_dict)
         expected_transcription = [["THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG"]]
         self.assertEqual(transcription, expected_transcription, msg=str(transcription))
...
@@ -32,9 +32,9 @@ print(torchaudio.__version__)
 import math
 import os

-import requests
 import matplotlib.pyplot as plt
+import requests

 from IPython.display import Audio, display
@@ -164,7 +164,7 @@ def get_rir_sample(*, resample=None, processed=False):
     rir_raw, sample_rate = _get_sample(SAMPLE_RIR_PATH, resample=resample)
     if not processed:
         return rir_raw, sample_rate
-    rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
+    rir = rir_raw[:, int(sample_rate * 1.01) : int(sample_rate * 1.3)]
     rir = rir / torch.norm(rir, p=2)
     rir = torch.flip(rir, [1])
     return rir, sample_rate
@@ -225,9 +225,7 @@ effects = [
 ]

 # Apply effects
-waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(
-    waveform1, sample_rate1, effects
-)
+waveform2, sample_rate2 = torchaudio.sox_effects.apply_effects_tensor(waveform1, sample_rate1, effects)

 print_stats(waveform1, sample_rate=sample_rate1, src="Original")
 print_stats(waveform2, sample_rate=sample_rate2, src="Effects Applied")
@@ -291,7 +289,7 @@ play_audio(rir_raw, sample_rate)
 # the signal power, then flip along the time axis.
 #

-rir = rir_raw[:, int(sample_rate * 1.01): int(sample_rate * 1.3)]
+rir = rir_raw[:, int(sample_rate * 1.01) : int(sample_rate * 1.3)]
 rir = rir / torch.norm(rir, p=2)
 rir = torch.flip(rir, [1])
...
@@ -33,10 +33,10 @@ print(torchaudio.__version__)
 # -------------------------------------------------------------------------------

 import os

-import requests
 import librosa
 import matplotlib.pyplot as plt
+import requests

 _SAMPLE_DIR = "_assets"
@@ -125,17 +125,13 @@ stretch = T.TimeStretch()
 rate = 1.2
 spec_ = stretch(spec, rate)
-plot_spectrogram(
-    torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
-)
+plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)

 plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)

 rate = 0.9
 spec_ = stretch(spec, rate)
-plot_spectrogram(
-    torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304
-)
+plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)

 ######################################################################
 # TimeMasking
...
@@ -51,10 +51,10 @@ print(torchaudio.__version__)
 # -------------------------------------------------------------------------------

 import os

-import requests
 import librosa
 import matplotlib.pyplot as plt
+import requests

 from IPython.display import Audio, display
@@ -199,9 +199,7 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
     axis2 = axis.twinx()
     time_axis = torch.linspace(0, end_time, nfcc.shape[1])
-    ln2 = axis2.plot(
-        time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--"
-    )
+    ln2 = axis2.plot(time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--")

     lns = ln1 + ln2
     labels = [l.get_label() for l in lns]
...
@@ -32,13 +32,13 @@ print(torchaudio.__version__)
 import io
 import os

-import requests
 import tarfile

 import boto3
+import matplotlib.pyplot as plt
+import requests
 from botocore import UNSIGNED
 from botocore.config import Config
-import matplotlib.pyplot as plt

 from IPython.display import Audio, display
@@ -348,14 +348,12 @@ frame_offset, num_frames = 16000, 16000  # Fetch and decode the 1 - 2 seconds

 print("Fetching all the data...")
 with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
     waveform1, sample_rate1 = torchaudio.load(response.raw)
-    waveform1 = waveform1[:, frame_offset: frame_offset + num_frames]
+    waveform1 = waveform1[:, frame_offset : frame_offset + num_frames]
     print(f" - Fetched {response.raw.tell()} bytes")

 print("Fetching until the requested frames are available...")
 with requests.get(SAMPLE_WAV_SPEECH_URL, stream=True) as response:
-    waveform2, sample_rate2 = torchaudio.load(
-        response.raw, frame_offset=frame_offset, num_frames=num_frames
-    )
+    waveform2, sample_rate2 = torchaudio.load(response.raw, frame_offset=frame_offset, num_frames=num_frames)
     print(f" - Fetched {response.raw.tell()} bytes")

 print("Checking the resulting waveform ... ", end="")
...
@@ -38,8 +38,8 @@ import time

 import librosa
 import matplotlib.pyplot as plt
-from IPython.display import Audio, display
 import pandas as pd
+from IPython.display import Audio, display

 DEFAULT_OFFSET = 201
@@ -56,9 +56,7 @@ def _get_log_freq(sample_rate, max_sweep_rate, offset):
     """
     start, stop = math.log(offset), math.log(offset + max_sweep_rate // 2)
-    return (
-        torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset
-    )
+    return torch.exp(torch.linspace(start, stop, sample_rate, dtype=torch.double)) - offset

 def _get_inverse_log_freq(freq, sample_rate, offset):
@@ -192,9 +190,7 @@ def benchmark_resample(
         waveform_np = waveform.squeeze().numpy()
         begin = time.time()
         for _ in range(iters):
-            librosa.resample(
-                waveform_np, sample_rate, resample_rate, res_type=librosa_type
-            )
+            librosa.resample(waveform_np, sample_rate, resample_rate, res_type=librosa_type)
         elapsed = time.time() - begin
         return elapsed / iters
@@ -264,14 +260,10 @@ play_audio(waveform, sample_rate)
 sample_rate = 48000
 resample_rate = 32000

-resampled_waveform = F.resample(
-    waveform, sample_rate, resample_rate, lowpass_filter_width=6
-)
+resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
 plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")

-resampled_waveform = F.resample(
-    waveform, sample_rate, resample_rate, lowpass_filter_width=128
-)
+resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
 plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
@@ -315,14 +307,10 @@ plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
 sample_rate = 48000
 resample_rate = 32000

-resampled_waveform = F.resample(
-    waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation"
-)
+resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
 plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")

-resampled_waveform = F.resample(
-    waveform, sample_rate, resample_rate, resampling_method="kaiser_window"
-)
+resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
 plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
@@ -351,13 +339,9 @@ resampled_waveform = F.resample(
 plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")

 librosa_resampled_waveform = torch.from_numpy(
-    librosa.resample(
-        waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best"
-    )
+    librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_best")
 ).unsqueeze(0)
-plot_sweep(
-    librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)"
-)
+plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")

 mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
 print("torchaudio and librosa kaiser best MSE:", mse)
@@ -372,18 +356,12 @@ resampled_waveform = F.resample(
     resampling_method="kaiser_window",
     beta=8.555504641634386,
 )
-plot_specgram(
-    resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)"
-)
+plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")

 librosa_resampled_waveform = torch.from_numpy(
-    librosa.resample(
-        waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast"
-    )
+    librosa.resample(waveform.squeeze().numpy(), sample_rate, resample_rate, res_type="kaiser_fast")
 ).unsqueeze(0)
-plot_sweep(
-    librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)"
-)
+plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")

 mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
 print("torchaudio and librosa kaiser fast MSE:", mse)
@@ -426,29 +404,19 @@ for label in configs:
     waveform = get_sine_sweep(sample_rate)

     # sinc 64 zero-crossings
-    f_time = benchmark_resample(
-        "functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64
-    )
-    t_time = benchmark_resample(
-        "transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64
-    )
+    f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
+    t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=64)
     times.append([None, 1000 * f_time, 1000 * t_time])
     rows.append("sinc (width 64)")

     # sinc 6 zero-crossings
-    f_time = benchmark_resample(
-        "functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16
-    )
-    t_time = benchmark_resample(
-        "transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16
-    )
+    f_time = benchmark_resample("functional", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
+    t_time = benchmark_resample("transforms", waveform, sample_rate, resample_rate, lowpass_filter_width=16)
     times.append([None, 1000 * f_time, 1000 * t_time])
     rows.append("sinc (width 16)")

     # kaiser best
-    lib_time = benchmark_resample(
-        "librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best"
-    )
+    lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_best")
     f_time = benchmark_resample(
         "functional",
         waveform,
@@ -473,9 +441,7 @@ for label in configs:
     rows.append("kaiser_best")

     # kaiser fast
-    lib_time = benchmark_resample(
-        "librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast"
-    )
+    lib_time = benchmark_resample("librosa", waveform, sample_rate, resample_rate, librosa_type="kaiser_fast")
     f_time = benchmark_resample(
         "functional",
         waveform,
@@ -499,8 +465,6 @@ for label in configs:
     times.append([1000 * lib_time, 1000 * f_time, 1000 * t_time])
     rows.append("kaiser_fast")

-    df = pd.DataFrame(
-        times, columns=["librosa", "functional", "transforms"], index=rows
-    )
+    df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
     df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
     display(df.round(2))
...
@@ -40,12 +40,12 @@ Recognition <https://arxiv.org/abs/2007.09127>`__.
 import os
 from dataclasses import dataclass

-import torch
-import torchaudio
-import requests
+import IPython
 import matplotlib
 import matplotlib.pyplot as plt
-import IPython
+import requests
+import torch
+import torchaudio

 matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
@@ -325,7 +325,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
     trellis_with_path = trellis.clone()
     for i, seg in enumerate(segments):
         if seg.label != "|":
-            trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
+            trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")

     fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
     ax1.set_title("Path, label and probability for each label")
@@ -383,12 +383,8 @@ def merge_words(segments, separator="|"):
         if i1 != i2:
             segs = segments[i1:i2]
             word = "".join([seg.label for seg in segs])
-            score = sum(seg.score * seg.length for seg in segs) / sum(
-                seg.length for seg in segs
-            )
-            words.append(
-                Segment(word, segments[i1].start, segments[i2 - 1].end, score)
-            )
+            score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
+            words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
             i1 = i2 + 1
             i2 = i1
         else:
@@ -408,7 +404,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
     trellis_with_path = trellis.clone()
     for i, seg in enumerate(segments):
         if seg.label != "|":
-            trellis_with_path[seg.start + 1: seg.end + 1, i + 1] = float("nan")
+            trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")

     fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
@@ -464,9 +460,7 @@ def display_segment(i):
     x1 = int(ratio * word.end)
     filename = f"_assets/{i}_{word.label}.wav"
     torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
-    print(
-        f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec"
-    )
+    print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
     return IPython.display.Audio(filename)
...
@@ -44,10 +44,11 @@ MVDR with torchaudio
 #

 import os
+
+import IPython.display as ipd
 import requests
 import torch
 import torchaudio
-import IPython.display as ipd

 torch.random.manual_seed(0)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -167,9 +168,7 @@ for solution in ["ref_channel", "stv_evd", "stv_power"]:
 results_single = {}
 for solution in ["ref_channel", "stv_evd", "stv_power"]:
-    mvdr = torchaudio.transforms.MVDR(
-        ref_channel=0, solution=solution, multi_mask=False
-    )
+    mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)
     stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])
     est = istft(stft_est, length=mix.shape[-1])
     results_single[solution] = est
@@ -211,9 +210,7 @@ def si_sdr(estimate, reference, epsilon=1e-8):
 #

 for solution in results_single:
-    print(
-        solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1])
-    )
+    print(solution + ": ", si_sdr(results_single[solution][None, ...], reverb_clean[0:1]))

 ######################################################################
 # Multi-channel mask results
@@ -221,9 +218,7 @@ for solution in results_single:
 #

 for solution in results_multi:
-    print(
-        solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1])
-    )
+    print(solution + ": ", si_sdr(results_multi[solution][None, ...], reverb_clean[0:1]))

 ######################################################################
 # Original audio
...
@@ -41,12 +41,12 @@ pre-trained models from wav2vec 2.0
 import os

-import torch
-import torchaudio
-import requests
+import IPython
 import matplotlib
 import matplotlib.pyplot as plt
-import IPython
+import requests
+import torch
+import torchaudio

 matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
...
@@ -7,6 +7,10 @@ Text-to-Speech with Tacotron2
 """

+import IPython
+import matplotlib
+import matplotlib.pyplot as plt
+
 ######################################################################
 # Overview
 # --------
@@ -58,10 +62,6 @@ Text-to-Speech with Tacotron2
 import torch
 import torchaudio
-
-import matplotlib
-import matplotlib.pyplot as plt
-import IPython

 matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
@@ -271,9 +271,7 @@ fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
 ax1.imshow(spec[0].cpu().detach())
 ax2.plot(waveforms[0].cpu().detach())

-torchaudio.save(
-    "_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate
-)
+torchaudio.save("_assets/output_wavernn.wav", waveforms[0:1].cpu(), sample_rate=vocoder.sample_rate)
 IPython.display.Audio("_assets/output_wavernn.wav")
@@ -332,9 +330,7 @@ checkpoint = torch.hub.load_state_dict_from_url(
     progress=False,
     map_location=device,
 )
-state_dict = {
-    key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()
-}
+state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}
 waveglow.load_state_dict(state_dict)
 waveglow = waveglow.remove_weightnorm(waveglow)
...