Unverified Commit 2e58f18a authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Refactor coding style for wavernn example (#1663)

parent 077a5f4a
import os
import random import random
import torch import torch
import torchaudio
from torch.utils.data.dataset import random_split from torch.utils.data.dataset import random_split
from torchaudio.datasets import LJSPEECH, LIBRITTS from torchaudio.datasets import LJSPEECH, LIBRITTS
from torchaudio.transforms import MuLawEncoding from torchaudio.transforms import MuLawEncoding
......
import argparse import argparse
import torch import torch
import torch.nn.functional as F
import torchaudio import torchaudio
from torchaudio.transforms import MelSpectrogram from torchaudio.transforms import MelSpectrogram
from torchaudio.models import wavernn from torchaudio.models import wavernn
......
import argparse import argparse
import logging import logging
import os import os
import signal
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from time import time from time import time
...@@ -9,7 +8,6 @@ from typing import List ...@@ -9,7 +8,6 @@ from typing import List
import torch import torch
import torchaudio import torchaudio
from torch import nn as nn
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator from torchaudio.datasets.utils import bg_iterator
......
# ***************************************************************************** # *****************************************************************************
# Copyright (c) 2019 fatchord (https://github.com/fatchord) # Copyright (c) 2019 fatchord (https://github.com/fatchord)
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights # in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is # copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions: # furnished to do so, subject to the following conditions:
# #
# The above copyright notice and this permission notice shall be included in all # The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software. # copies or substantial portions of the Software.
# #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
...@@ -34,6 +34,7 @@ from processing import ( ...@@ -34,6 +34,7 @@ from processing import (
bits_to_normalized_waveform, bits_to_normalized_waveform,
) )
class WaveRNNInferenceWrapper(torch.nn.Module): class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN): def __init__(self, wavernn: WaveRNN):
...@@ -157,7 +158,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -157,7 +158,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
padded[:, :, :t] = x padded[:, :, :t] = x
else: else:
raise ValueError(f"Unexpected side: '{side}'. " raise ValueError(f"Unexpected side: '{side}'. "
f"Valid choices are 'both', 'before' and 'after'.") f"Valid choices are 'both', 'before' and 'after'.")
return padded return padded
def forward(self, def forward(self,
...@@ -242,7 +243,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -242,7 +243,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
output.append(x.squeeze(-1)) output.append(x.squeeze(-1))
else: else:
raise ValueError(f"Unexpected loss_name: '{loss_name}'. " raise ValueError(f"Unexpected loss_name: '{loss_name}'. "
f"Valid choices are 'crossentropy'.") f"Valid choices are 'crossentropy'.")
output = torch.stack(output).transpose(0, 1).cpu() output = torch.stack(output).transpose(0, 1).cpu()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment